In [None]:
import torchvision.models as models
import os
import wandb
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

In [None]:
class ResNet18(nn.Module):
    """torchvision ResNet-18 with a 512 -> 128 adapter head before the classifier.

    The backbone stages are exposed as attributes (``conv1`` ... ``avgpool``)
    that alias the modules inside ``self.resnet``. ``fc`` is a
    ``Linear(128, num_classes)`` applied after the adapter.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # NOTE: newer torchvision releases deprecate ``pretrained`` in favour of
        # ``weights=None``; kept as-is for compatibility with older versions.
        self.resnet = models.resnet18(pretrained=False, num_classes=num_classes)
        # Bottleneck the 512-d pooled backbone features down to 128-d.
        self.adapter = nn.Sequential(nn.Linear(512, 128), nn.ReLU())
        # Replace torchvision's head so it consumes the 128-d adapter output.
        self.resnet.fc = nn.Sequential(nn.Linear(128, num_classes))
        # Aliases to the backbone stages. They share parameters with
        # ``self.resnet``; the attribute names also appear as state_dict keys,
        # so they are kept for checkpoint compatibility.
        self.conv1 = self.resnet.conv1
        self.bn1 = self.resnet.bn1
        self.relu = self.resnet.relu
        self.maxpool = self.resnet.maxpool
        self.layer1 = self.resnet.layer1
        self.layer2 = self.resnet.layer2
        self.layer3 = self.resnet.layer3
        self.layer4 = self.resnet.layer4
        self.avgpool = self.resnet.avgpool
        self.fc = self.resnet.fc

    def forward(self, x, no_fc=False):
        """Run the network on a batch of images.

        Args:
            x: image batch of shape (N, C, H, W). Single-channel input (e.g.
                MNIST) is replicated to 3 channels when ``conv1`` expects 3.
            no_fc: if True, return the 128-d adapter features instead of logits.

        Returns:
            (N, 128) adapter features when ``no_fc`` is True, otherwise
            (N, num_classes) logits.
        """
        # torchvision's conv1 is built for 3-channel input; a grayscale batch
        # would otherwise fail with a channel-mismatch error.
        if x.dim() == 4 and x.size(1) == 1 and self.conv1.in_channels == 3:
            x = x.expand(-1, 3, -1, -1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.adapter(x)

        if no_fc:
            return x
        return self.fc(x)


In [None]:
# Output directory for generated image grids. exist_ok=True makes re-running
# the cell idempotent without a separate (racy) existence check.
os.makedirs('./mlp_img', exist_ok=True)

def to_img(x, height=28, width=28):
    """Convert a batch of normalized images back to displayable tensors.

    Inverts ``Normalize([0.5], [0.5])`` (maps [-1, 1] back to [0, 1]), clamps
    to the valid pixel range, and reshapes to (N, 1, height, width).

    Args:
        x: tensor whose first dim is the batch and whose remaining elements
            number ``height * width`` per sample.
        height: image height; defaults to MNIST's 28.
        width: image width; defaults to MNIST's 28.

    Returns:
        Tensor of shape (N, 1, height, width) with values in [0, 1].
    """
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    # reshape (unlike view) also works when x is not contiguous in memory.
    return x.reshape(x.size(0), 1, height, width)


# Training hyperparameters.
num_epochs = 10
batch_size = 512
learning_rate = 1e-3

# Seed torch's global RNG so model initialisation and DataLoader shuffling are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# ToTensor scales pixels to [0, 1]; Normalize([0.5], [0.5]) then maps them to [-1, 1].
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

dataset_train = MNIST('./data', transform=img_transform, download=True, train=True)
dataset_test = MNIST('./data', transform=img_transform, download=True, train=False)

# Shuffled mini-batch loader over the training split only.
dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)


In [None]:
# Train the model on dataset_train; evaluation on dataset_test happens in the next cell.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet18(10).to(device)
# MNIST labels are integer class indices and the model outputs raw logits, so
# this needs a classification loss. MSELoss expects float targets shaped like
# the output and errors on these (N,) int64 labels.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# An evaluation cell may have left the model in eval mode (BatchNorm would then
# use frozen running statistics); switch back before training.
model.train()
for epoch in range(num_epochs):
    for i, (img, label) in enumerate(dataloader):
        img = img.to(device)
        label = label.to(device)

        output = model(img)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, len(dataloader), loss.item()))

    # Save a checkpoint after every epoch (overwrites the previous one).
    torch.save(model.state_dict(), './resnet.ckpt')

   




In [None]:
# Evaluate the model on the held-out MNIST test split.
# A dedicated, unshuffled loader over dataset_test is used here; iterating the
# training `dataloader` would report training accuracy as "test" accuracy.
test_dataloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

model.eval()
test_loss = 0.0
correct = 0
with torch.no_grad():
    for img, label in test_dataloader:
        img = img.to(device)
        label = label.to(device)

        output = model(img)
        # Sum (not mean) per-sample losses so dividing by the dataset size
        # below yields a true per-sample average regardless of batch size.
        # Computed directly so this cell does not depend on `criterion`.
        test_loss += nn.functional.cross_entropy(output, label, reduction='sum').item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(label.view_as(pred)).sum().item()

num_test = len(test_dataloader.dataset)
test_loss /= num_test
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
        .format(test_loss, correct, num_test,
                100. * correct / num_test))

# Save the model checkpoints
torch.save(model.state_dict(), './resnet.ckpt')

# Evaluate the model using the test dataset
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for img, label in dataloader:
        img = img.to(device)