In [None]:
import torch 
from torch import nn,optim
from torch.nn import functional as F
from torchvision import datasets,models,transforms
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [None]:
# STL-10 labelled training split as float tensors in [0, 1] (ToTensor only, no normalisation).
# download=True lets a fresh kernel / fresh machine fetch the archive into ./stl; per the torchvision
# docs it skips the download when the files are already present and verified.
train = datasets.STL10('stl', split='train', transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(
    train,
    batch_size=256,
    shuffle=True,
    num_workers=8,
    # batches are copied to the GPU every step in the training loop; pinned host memory speeds that up
    pin_memory=torch.cuda.is_available(),
)

In [None]:
class Model(nn.Module):
    """ResNet-18 trunk + global average pool + linear classifier, with class activation maps.

    The convolutional trunk is torchvision's ``resnet18`` built with ``pretrained=False``
    (random init) with its last two children removed; ``avg_pool`` and ``fc`` replace them,
    which lets the ``fc`` weights be reused to weight the final feature maps in ``get_cam``.
    """

    def __init__(self):
        super().__init__()
        # every child except the trailing pooling and fc layers -> 512-channel feature maps
        self.conv = nn.Sequential(*list(models.resnet18(pretrained=False).children())[:-2])
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, 10)

    def _classify(self, x_conv):
        """Turn conv feature maps (N, 512, h, w) into class logits (N, 10)."""
        x = self.avg_pool(x_conv)
        x = torch.flatten(x, 1)
        return self.fc(x)

    def forward(self, x):
        """Return class logits of shape (N, 10) for an image batch ``x``."""
        return self._classify(self.conv(x))

    def get_cam(self, x):
        """Class activation map for the model's own top-1 prediction on each image of ``x``.

        For each image, the fc weight row of the predicted class weights the final conv
        feature maps, which are summed over channels and bilinearly upsampled to the
        input's spatial size.

        Args:
            x: image batch of shape (N, C, H, W), N >= 1, on the same device as the model.

        Returns:
            numpy array of shape (N, H, W). The model's train/eval mode is restored
            before returning.
        """
        was_training = self.training
        self.eval()
        try:
            # inference only: no autograd graph needed
            with torch.no_grad():
                maps = self.conv(x)                      # (N, 512, h, w)
                # reuse the feature maps instead of running the trunk a second time
                cl = self._classify(maps).max(1)[1]      # (N,) predicted class per image
                w = self.fc.weight[cl]                   # (N, 512)
                # per-image weighted channel sum; explicit broadcasting keeps batches aligned
                cam = (w[:, :, None, None] * maps).sum(1)  # (N, h, w)
                cam = F.interpolate(
                    cam.unsqueeze(1), size=x.shape[2:], mode='bilinear', align_corners=False
                ).squeeze(1)                             # (N, H, W)
        finally:
            self.train(was_training)
        return cam.cpu().numpy()

In [None]:
def test_cam(model, dataset=None, n_images=10, n_cols=5):
    """Plot the model's class activation map over the first ``n_images`` samples of ``dataset``.

    Args:
        model: object exposing ``get_cam`` and ``parameters`` (the ``Model`` class here).
        dataset: indexable dataset yielding ``(image_tensor, label)`` with CHW image tensors;
            defaults to the notebook-level ``train`` split.
        n_images: how many leading samples to plot.
        n_cols: subplot columns; the number of rows is derived from ``n_images``.
    """
    if dataset is None:
        dataset = train
    # send inputs to wherever the model's weights live rather than assuming CUDA
    device = next(model.parameters()).device
    n_rows = -(-n_images // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx >= n_images:
            continue
        image = dataset[idx][0]  # index once: each lookup reloads and re-transforms the sample
        cam = model.get_cam(image.unsqueeze(0).to(device))
        ax.imshow(image.permute(1, 2, 0))
        ax.imshow(cam[0], cmap='jet', alpha=0.3)
    plt.show()


In [None]:
# Fresh, randomly initialised network, placed on the GPU before the optimizer is built from its parameters.
model = Model().cuda()

In [None]:
# Adadelta with its stock hyper-parameters (lr=1.0 is the PyTorch default, spelled out for readers).
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
# Cross-entropy over the 10 class logits, averaged over the batch (default reduction).
criterion = nn.CrossEntropyLoss()

In [None]:
NUM_EPOCHS = 20
LOG_EVERY = 200  # batches between running-loss printouts

for epoch in range(NUM_EPOCHS):
    model.train()  # get_cam / evaluation may have switched the model to eval mode
    running_loss = 0.0
    correct, tot = 0, 0
    loop = tqdm(train_loader)
    for i, (inputs, labels) in enumerate(loop):
        inputs, labels = inputs.cuda(), labels.cuda()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion is nn.CrossEntropyLoss() with the default 'mean' reduction, so this is
        # already a per-sample average for the batch.
        running_loss += loss.item()

        cl = outputs.max(1)[1]
        # .item() keeps the counters as Python numbers, so the progress bar shows a plain float
        # instead of a CUDA tensor repr.
        correct += (cl == labels).sum().item()
        tot += labels.size(0)

        loop.set_description(f'acc:{correct / tot:.4f}')

        if i % LOG_EVERY == 0:
            # mean of the per-batch mean losses seen so far this epoch
            print(running_loss / (i + 1))

    print(f'epoch: {epoch}')
    test_cam(model)
