In [1]:
from time import time

import torch
from torchvision import datasets, transforms
from tqdm import tqdm

from model_bruno import *
from utils import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip()      
])

STL10_set = datasets.STL10(root='C:/Users/bfons/OneDrive/Documents/Datasets/STL10', split='train', transform=transformation, download=False)
STL10_loader = torch.utils.data.DataLoader(STL10_set, batch_size=1, shuffle=True,  pin_memory=True, num_workers=6)

In [3]:
num_epochs = 100
num_clusters = 10
num_classes = 10
area_penalty = False

model = AttModel(n_classes=num_classes, n_clusters=num_clusters)
model = model.to(device)
opt = torch.optim.Adam(model.parameters(), 1e-4)
criterion = nn.CrossEntropyLoss()



In [4]:
model.train()
for epoch in range(num_epochs):
    tic = time()
    avg_loss = 0
    avg_acc = 0
    for x,y in tqdm(STL10_loader):
        x = x.to(device)
        y = y.to(device)

        pred, _, _ = model(x)
        loss = criterion(pred.view(1,-1),y)        
        # if area_penalty:
        #     loss += torch.sum((bboxes[:,-1]*bboxes[:,-2]))/num_clusters**2

        opt.zero_grad()
        loss.backward(retain_graph=False)        
        opt.step()
        
        avg_loss += float(loss) / len(STL10_loader)
        avg_acc += (y == pred.argmax(1)).float().mean() / len(STL10_loader)   
    
    toc = time()
    
    print(f'Epoch {epoch+1}/{num_epochs} - {toc-tic:.0f}s - Avg loss: {avg_loss} - Avg acc: {avg_acc}')

100%|██████████| 5000/5000 [16:35<00:00,  5.02it/s]


Epoch 1/100 - 995s - Avg loss: 2.057565796540679 - Avg acc: 0.19260211288928986


100%|██████████| 5000/5000 [16:33<00:00,  5.03it/s]


Epoch 2/100 - 993s - Avg loss: 1.7781131371051082 - Avg acc: 0.2874037027359009


100%|██████████| 5000/5000 [16:51<00:00,  4.94it/s]


Epoch 3/100 - 1012s - Avg loss: 1.6691619516856973 - Avg acc: 0.3278043866157532


100%|██████████| 5000/5000 [16:46<00:00,  4.97it/s]


Epoch 4/100 - 1006s - Avg loss: 1.6026229110986037 - Avg acc: 0.3582049012184143


100%|██████████| 5000/5000 [16:45<00:00,  4.97it/s]


Epoch 5/100 - 1005s - Avg loss: 1.5555413918927274 - Avg acc: 0.3790052533149719


100%|██████████| 5000/5000 [16:45<00:00,  4.97it/s]


Epoch 6/100 - 1005s - Avg loss: 1.4972220715507805 - Avg acc: 0.4158058762550354


100%|██████████| 5000/5000 [16:47<00:00,  4.96it/s]


Epoch 7/100 - 1007s - Avg loss: 1.4397568018815006 - Avg acc: 0.446806401014328


100%|██████████| 5000/5000 [16:48<00:00,  4.96it/s]


Epoch 8/100 - 1008s - Avg loss: 1.3679167938694314 - Avg acc: 0.46800675988197327


100%|██████████| 5000/5000 [16:52<00:00,  4.94it/s]


Epoch 9/100 - 1012s - Avg loss: 1.28343326691645 - Avg acc: 0.5110058188438416


100%|██████████| 5000/5000 [16:50<00:00,  4.95it/s]


Epoch 10/100 - 1010s - Avg loss: 1.2262129393208312 - Avg acc: 0.5330029129981995


100%|██████████| 5000/5000 [16:55<00:00,  4.93it/s]


Epoch 11/100 - 1015s - Avg loss: 1.1689356350910869 - Avg acc: 0.5510005354881287


100%|██████████| 5000/5000 [16:57<00:00,  4.92it/s]


Epoch 12/100 - 1017s - Avg loss: 1.1293422840609486 - Avg acc: 0.5697980523109436


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 13/100 - 1015s - Avg loss: 1.0787007778163835 - Avg acc: 0.5921950936317444


100%|██████████| 5000/5000 [16:52<00:00,  4.94it/s]


Epoch 14/100 - 1012s - Avg loss: 1.0420211740927072 - Avg acc: 0.6031936407089233


100%|██████████| 5000/5000 [16:55<00:00,  4.93it/s]


Epoch 15/100 - 1015s - Avg loss: 1.012712971453463 - Avg acc: 0.6141921877861023


100%|██████████| 5000/5000 [16:55<00:00,  4.92it/s]


Epoch 16/100 - 1015s - Avg loss: 0.9643213918008473 - Avg acc: 0.6349894404411316


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 17/100 - 1015s - Avg loss: 0.9427976785075826 - Avg acc: 0.6525871157646179


100%|██████████| 5000/5000 [16:55<00:00,  4.92it/s]


Epoch 18/100 - 1016s - Avg loss: 0.9004605265416571 - Avg acc: 0.6559866666793823


100%|██████████| 5000/5000 [16:57<00:00,  4.91it/s]


Epoch 19/100 - 1018s - Avg loss: 0.8600008048166884 - Avg acc: 0.681983232498169


100%|██████████| 5000/5000 [16:57<00:00,  4.91it/s]


Epoch 20/100 - 1018s - Avg loss: 0.8302746159214326 - Avg acc: 0.6883823871612549


100%|██████████| 5000/5000 [16:52<00:00,  4.94it/s]


Epoch 21/100 - 1012s - Avg loss: 0.7815017929540238 - Avg acc: 0.7065799832344055


100%|██████████| 5000/5000 [16:55<00:00,  4.93it/s]


Epoch 22/100 - 1015s - Avg loss: 0.7446834722652618 - Avg acc: 0.721977949142456


100%|██████████| 5000/5000 [16:52<00:00,  4.94it/s]


Epoch 23/100 - 1012s - Avg loss: 0.7068646227520582 - Avg acc: 0.7371759414672852


100%|██████████| 5000/5000 [16:58<00:00,  4.91it/s]


Epoch 24/100 - 1019s - Avg loss: 0.6670534837545977 - Avg acc: 0.7509741187095642


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 25/100 - 1014s - Avg loss: 0.6439412832838193 - Avg acc: 0.7645723223686218


100%|██████████| 5000/5000 [16:58<00:00,  4.91it/s]


Epoch 26/100 - 1019s - Avg loss: 0.5850015805661078 - Avg acc: 0.7819700241088867


100%|██████████| 5000/5000 [16:51<00:00,  4.95it/s]


Epoch 27/100 - 1011s - Avg loss: 0.5625328691937792 - Avg acc: 0.7963681221008301


100%|██████████| 5000/5000 [16:55<00:00,  4.92it/s]


Epoch 28/100 - 1015s - Avg loss: 0.5141613081987548 - Avg acc: 0.8087664842605591


100%|██████████| 5000/5000 [17:00<00:00,  4.90it/s]


Epoch 29/100 - 1021s - Avg loss: 0.4947564304870151 - Avg acc: 0.826164186000824


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 30/100 - 1014s - Avg loss: 0.4413239006893358 - Avg acc: 0.8439618349075317


100%|██████████| 5000/5000 [16:57<00:00,  4.91it/s]


Epoch 31/100 - 1018s - Avg loss: 0.42378242716492526 - Avg acc: 0.847761332988739


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 32/100 - 1014s - Avg loss: 0.37499464110579084 - Avg acc: 0.8677586913108826


100%|██████████| 5000/5000 [16:57<00:00,  4.91it/s]


Epoch 33/100 - 1017s - Avg loss: 0.34573367444482966 - Avg acc: 0.878957211971283


100%|██████████| 5000/5000 [16:57<00:00,  4.91it/s]


Epoch 34/100 - 1017s - Avg loss: 0.3308849254335656 - Avg acc: 0.8847564458847046


100%|██████████| 5000/5000 [16:53<00:00,  4.93it/s]


Epoch 35/100 - 1014s - Avg loss: 0.2984389823632357 - Avg acc: 0.8951550722122192


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 36/100 - 1015s - Avg loss: 0.2686720109046059 - Avg acc: 0.9085533022880554


100%|██████████| 5000/5000 [17:03<00:00,  4.88it/s]


Epoch 37/100 - 1024s - Avg loss: 0.263826307699321 - Avg acc: 0.9099531173706055


100%|██████████| 5000/5000 [16:53<00:00,  4.93it/s]


Epoch 38/100 - 1013s - Avg loss: 0.2196423880778889 - Avg acc: 0.926750898361206


100%|██████████| 5000/5000 [17:03<00:00,  4.89it/s]


Epoch 39/100 - 1023s - Avg loss: 0.21154142526639358 - Avg acc: 0.9317502379417419


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 40/100 - 1014s - Avg loss: 0.18682999732533262 - Avg acc: 0.9335500001907349


100%|██████████| 5000/5000 [16:57<00:00,  4.91it/s]


Epoch 41/100 - 1017s - Avg loss: 0.17177150827127946 - Avg acc: 0.9411489963531494


100%|██████████| 5000/5000 [16:56<00:00,  4.92it/s]


Epoch 42/100 - 1017s - Avg loss: 0.17860400785240868 - Avg acc: 0.942348837852478


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 43/100 - 1014s - Avg loss: 0.1529198646342175 - Avg acc: 0.9531474113464355


100%|██████████| 5000/5000 [17:00<00:00,  4.90it/s]


Epoch 44/100 - 1020s - Avg loss: 0.15788094158619032 - Avg acc: 0.948747992515564


100%|██████████| 5000/5000 [16:59<00:00,  4.90it/s]


Epoch 45/100 - 1020s - Avg loss: 0.13661208589956514 - Avg acc: 0.9549471735954285


100%|██████████| 5000/5000 [17:00<00:00,  4.90it/s]


Epoch 46/100 - 1021s - Avg loss: 0.13121291547497527 - Avg acc: 0.9593465924263


100%|██████████| 5000/5000 [17:00<00:00,  4.90it/s]


Epoch 47/100 - 1020s - Avg loss: 0.1164516010490692 - Avg acc: 0.9617462754249573


100%|██████████| 5000/5000 [17:01<00:00,  4.89it/s]


Epoch 48/100 - 1022s - Avg loss: 0.11871419942400388 - Avg acc: 0.9619462490081787


100%|██████████| 5000/5000 [17:00<00:00,  4.90it/s]


Epoch 49/100 - 1020s - Avg loss: 0.10556047791595442 - Avg acc: 0.9663456678390503


100%|██████████| 5000/5000 [16:52<00:00,  4.94it/s]


Epoch 50/100 - 1013s - Avg loss: 0.10572969698171569 - Avg acc: 0.9677454829216003


100%|██████████| 5000/5000 [17:00<00:00,  4.90it/s]


Epoch 51/100 - 1021s - Avg loss: 0.09760664149581329 - Avg acc: 0.9679454565048218


100%|██████████| 5000/5000 [16:59<00:00,  4.90it/s]


Epoch 52/100 - 1020s - Avg loss: 0.09181144915482065 - Avg acc: 0.9705451130867004


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 53/100 - 1014s - Avg loss: 0.09011404890655682 - Avg acc: 0.9719449281692505


100%|██████████| 5000/5000 [16:53<00:00,  4.93it/s]


Epoch 54/100 - 1013s - Avg loss: 0.08854897972082071 - Avg acc: 0.9695452451705933


100%|██████████| 5000/5000 [16:55<00:00,  4.93it/s]


Epoch 55/100 - 1015s - Avg loss: 0.08240169105340213 - Avg acc: 0.9725448489189148


100%|██████████| 5000/5000 [16:56<00:00,  4.92it/s]


Epoch 56/100 - 1017s - Avg loss: 0.09383221991893276 - Avg acc: 0.9687453508377075


100%|██████████| 5000/5000 [16:56<00:00,  4.92it/s]


Epoch 57/100 - 1017s - Avg loss: 0.08299059013246839 - Avg acc: 0.9725448489189148


100%|██████████| 5000/5000 [17:00<00:00,  4.90it/s]


Epoch 58/100 - 1020s - Avg loss: 0.06883966559794151 - Avg acc: 0.9779441356658936


100%|██████████| 5000/5000 [16:58<00:00,  4.91it/s]


Epoch 59/100 - 1018s - Avg loss: 0.06646687673862166 - Avg acc: 0.9785440564155579


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 60/100 - 1014s - Avg loss: 0.07577517557138212 - Avg acc: 0.9759443998336792


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 61/100 - 1014s - Avg loss: 0.06692902711001816 - Avg acc: 0.9791439771652222


100%|██████████| 5000/5000 [16:55<00:00,  4.92it/s]


Epoch 62/100 - 1016s - Avg loss: 0.07359220906029479 - Avg acc: 0.9769442677497864


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 63/100 - 1014s - Avg loss: 0.06822083301729913 - Avg acc: 0.9789440035820007


100%|██████████| 5000/5000 [16:51<00:00,  4.94it/s]


Epoch 64/100 - 1011s - Avg loss: 0.06599121438271496 - Avg acc: 0.9789440035820007


100%|██████████| 5000/5000 [16:51<00:00,  4.95it/s]


Epoch 65/100 - 1011s - Avg loss: 0.06194057181781555 - Avg acc: 0.9821435809135437


100%|██████████| 5000/5000 [16:52<00:00,  4.94it/s]


Epoch 66/100 - 1013s - Avg loss: 0.05930641811558008 - Avg acc: 0.9799438714981079


100%|██████████| 5000/5000 [16:59<00:00,  4.90it/s]


Epoch 67/100 - 1019s - Avg loss: 0.06042293749368547 - Avg acc: 0.9831434488296509


100%|██████████| 5000/5000 [16:51<00:00,  4.94it/s]


Epoch 68/100 - 1011s - Avg loss: 0.059214052478101346 - Avg acc: 0.9805437922477722


100%|██████████| 5000/5000 [16:53<00:00,  4.93it/s]


Epoch 69/100 - 1014s - Avg loss: 0.046849284936609226 - Avg acc: 0.9837433695793152


100%|██████████| 5000/5000 [16:52<00:00,  4.94it/s]


Epoch 70/100 - 1012s - Avg loss: 0.06610399022797071 - Avg acc: 0.9803438186645508


100%|██████████| 5000/5000 [16:51<00:00,  4.94it/s]


Epoch 71/100 - 1012s - Avg loss: 0.04909370513678901 - Avg acc: 0.9847432374954224


100%|██████████| 5000/5000 [16:51<00:00,  4.94it/s]


Epoch 72/100 - 1011s - Avg loss: 0.057140137038594864 - Avg acc: 0.9805437922477722


100%|██████████| 5000/5000 [16:59<00:00,  4.90it/s]


Epoch 73/100 - 1019s - Avg loss: 0.05526414395833078 - Avg acc: 0.9831434488296509


100%|██████████| 5000/5000 [16:50<00:00,  4.95it/s]


Epoch 74/100 - 1010s - Avg loss: 0.05398460398468166 - Avg acc: 0.9845432639122009


100%|██████████| 5000/5000 [16:59<00:00,  4.91it/s]


Epoch 75/100 - 1019s - Avg loss: 0.0495215111159076 - Avg acc: 0.9841433167457581


100%|██████████| 5000/5000 [16:47<00:00,  4.96it/s]


Epoch 76/100 - 1008s - Avg loss: 0.043395727733409925 - Avg acc: 0.9865429997444153


100%|██████████| 5000/5000 [16:45<00:00,  4.97it/s]


Epoch 77/100 - 1006s - Avg loss: 0.05374884872819746 - Avg acc: 0.982743501663208


100%|██████████| 5000/5000 [16:59<00:00,  4.91it/s]


Epoch 78/100 - 1019s - Avg loss: 0.048802847487028136 - Avg acc: 0.9869429469108582


100%|██████████| 5000/5000 [16:52<00:00,  4.94it/s]


Epoch 79/100 - 1013s - Avg loss: 0.04449798277732014 - Avg acc: 0.9863430261611938


100%|██████████| 5000/5000 [16:49<00:00,  4.95it/s]


Epoch 80/100 - 1010s - Avg loss: 0.04231207070077753 - Avg acc: 0.9861430525779724


100%|██████████| 5000/5000 [16:52<00:00,  4.94it/s]


Epoch 81/100 - 1012s - Avg loss: 0.05291082284656635 - Avg acc: 0.9845432639122009


100%|██████████| 5000/5000 [16:56<00:00,  4.92it/s]


Epoch 82/100 - 1017s - Avg loss: 0.04277968316459853 - Avg acc: 0.9869429469108582


100%|██████████| 5000/5000 [17:00<00:00,  4.90it/s]


Epoch 83/100 - 1021s - Avg loss: 0.038159254103237436 - Avg acc: 0.9877428412437439


100%|██████████| 5000/5000 [16:56<00:00,  4.92it/s]


Epoch 84/100 - 1017s - Avg loss: 0.04089808912787719 - Avg acc: 0.9881427884101868


100%|██████████| 5000/5000 [16:57<00:00,  4.92it/s]


Epoch 85/100 - 1017s - Avg loss: 0.039502204663443834 - Avg acc: 0.9855431318283081


100%|██████████| 5000/5000 [16:58<00:00,  4.91it/s]


Epoch 86/100 - 1018s - Avg loss: 0.03752547373222594 - Avg acc: 0.9883427619934082


100%|██████████| 5000/5000 [16:59<00:00,  4.90it/s]


Epoch 87/100 - 1020s - Avg loss: 0.03809625358313208 - Avg acc: 0.989142656326294


100%|██████████| 5000/5000 [16:51<00:00,  4.94it/s]


Epoch 88/100 - 1012s - Avg loss: 0.0447676053170814 - Avg acc: 0.9857431054115295


100%|██████████| 5000/5000 [17:02<00:00,  4.89it/s]


Epoch 89/100 - 1022s - Avg loss: 0.043277290849965185 - Avg acc: 0.9861430525779724


100%|██████████| 5000/5000 [16:56<00:00,  4.92it/s]


Epoch 90/100 - 1016s - Avg loss: 0.0294282774186537 - Avg acc: 0.9909424185752869


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 91/100 - 1015s - Avg loss: 0.04600568844941143 - Avg acc: 0.9861430525779724


100%|██████████| 5000/5000 [16:46<00:00,  4.97it/s]


Epoch 92/100 - 1007s - Avg loss: 0.03490368006991715 - Avg acc: 0.989142656326294


100%|██████████| 5000/5000 [16:56<00:00,  4.92it/s]


Epoch 93/100 - 1016s - Avg loss: 0.03463449452271352 - Avg acc: 0.9893426299095154


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 94/100 - 1015s - Avg loss: 0.034537870785786996 - Avg acc: 0.9889426827430725


100%|██████████| 5000/5000 [16:52<00:00,  4.94it/s]


Epoch 95/100 - 1013s - Avg loss: 0.04344030027846432 - Avg acc: 0.9875428676605225


100%|██████████| 5000/5000 [16:43<00:00,  4.98it/s]


Epoch 96/100 - 1003s - Avg loss: 0.03852048129695413 - Avg acc: 0.989142656326294


100%|██████████| 5000/5000 [16:49<00:00,  4.95it/s]


Epoch 97/100 - 1009s - Avg loss: 0.03249614986841405 - Avg acc: 0.9895426034927368


100%|██████████| 5000/5000 [16:54<00:00,  4.93it/s]


Epoch 98/100 - 1015s - Avg loss: 0.03876413683477101 - Avg acc: 0.9901425242424011


100%|██████████| 5000/5000 [16:45<00:00,  4.97it/s]


Epoch 99/100 - 1006s - Avg loss: 0.033863824625843336 - Avg acc: 0.9909424185752869


100%|██████████| 5000/5000 [16:58<00:00,  4.91it/s]

Epoch 100/100 - 1018s - Avg loss: 0.03598073036049054 - Avg acc: 0.9885427355766296





In [5]:
torch.save(model.cpu(),'./model.pt')

The model appears to not converge. Maybe replace Sigmoid activations for ReLU. It is still only working with batch_size = 1. Also need to evaluate the learning rate and which features should be used for clustering.