In [None]:
import torchvision.models as models
import os
import wandb
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
from torchvision.utils import save_image
import matplotlib.pyplot as plt


In [None]:
class ResNet18(nn.Module):
    """ResNet-18 classifier with a small adapter in front of the output layer.

    A torchvision ``resnet18`` is created without pretrained weights. Its
    ``fc`` head is swapped for ``Sequential(Linear(128, num_classes))`` and an
    ``adapter`` (``Linear(512, 128)`` followed by ``ReLU``) is applied to the
    flattened, pooled features before that head.

    The backbone's sub-modules are additionally exposed directly on this
    module (``conv1`` ... ``fc``); they are the very same objects that
    ``self.resnet`` holds, not copies.

    Args:
        num_classes: number of output logits produced by the final layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(pretrained=False, num_classes=num_classes)
        self.resnet = backbone
        self.adapter = nn.Sequential(nn.Linear(512, 128), nn.ReLU())
        # Replace the stock head so it consumes the adapter's 128 features.
        backbone.fc = nn.Sequential(nn.Linear(128, num_classes))

        # Re-export the backbone stages under their usual names.
        for name in ("conv1", "bn1", "relu", "maxpool",
                     "layer1", "layer2", "layer3", "layer4",
                     "avgpool", "fc"):
            setattr(self, name, getattr(backbone, name))

    def forward(self, x, no_fc=False):
        """Run the network on ``x``.

        Args:
            x: input batch fed to the first convolution.
            no_fc: when true, stop after the adapter and return its output
                instead of the class logits.

        Returns:
            The adapter output if ``no_fc`` is true, otherwise the output of
            the final ``fc`` head.
        """
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        for stage in pipeline:
            x = stage(x)

        features = self.adapter(torch.flatten(x, 1))
        return features if no_fc else self.fc(features)



In [None]:
# Make sure the image output directory exists. exist_ok=True replaces the
# separate exists()/mkdir() pair, which could race with another process and
# is needlessly verbose; the cell stays safe to re-run.
os.makedirs('./mlp_img', exist_ok=True)

def to_img(x):
    """Map a batch from the [-1, 1] range to [0, 1] and shape it as 32x32 images.

    Values are rescaled with ``0.5 * (x + 1)`` and clamped to ``[0, 1]``.
    The result is viewed as ``(batch, channels, 32, 32)``. The channel count
    is inferred from the number of elements per sample, so both flattened
    single-channel input (1024 values per sample) and three-channel input
    such as ``(batch, 3, 32, 32)`` are accepted. Previously the channel count
    was fixed at 1 and multi-channel batches raised an error.

    Args:
        x: tensor whose first dimension is the batch dimension.

    Returns:
        Tensor of shape ``(batch, channels, 32, 32)`` with values in [0, 1].

    Raises:
        RuntimeError: if the per-sample element count is not a multiple of
            ``32 * 32`` (raised by ``view``).
    """
    batch = x.size(0)
    pixels = 32 * 32
    # Fall back to one channel for empty batches (nothing to infer from) and
    # for undersized inputs, so `view` reports the mismatch as it always did.
    channels = max(1, x.numel() // (batch * pixels)) if batch else 1
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    return x.view(batch, channels, 32, 32)


# Training hyper-parameters.
num_epochs = 10
batch_size = 512
learning_rate = 1e-3

# Generic [-1, 1] normalisation. Defined here but not used by the CIFAR-10
# pipeline below (an MNIST variant of the datasets used to live here).
img_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

# Per-channel statistics passed to Normalize for both CIFAR-10 splits.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training images get random crop + horizontal flip augmentation.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)
# Test images are only converted and normalised.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# CIFAR-10 splits, downloaded into ./data when missing.
dataset_train = CIFAR10(root='./data', train=True, transform=train_transform, download=True)
dataset_test = CIFAR10(root='./data', train=False, transform=test_transform, download=True)

# Batch iterators: shuffled for training, fixed order for evaluation.
train_loader = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True, num_workers=6)
test_loader = DataLoader(dataset=dataset_test, batch_size=100, shuffle=False, num_workers=6)

# Human-readable class names, one per label index.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


In [None]:
# Train a freshly initialised ResNet18 on the training loader.
# (The weights are written to disk by the evaluation cell that follows.)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = ResNet18(num_classes=10)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

steps_per_epoch = len(train_loader)
for epoch in range(num_epochs):
    for step, (images, labels) in enumerate(train_loader, start=1):
        images, labels = images.to(device), labels.to(device)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the current batch loss every 100 steps.
        if step % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, step, steps_per_epoch, loss.item()))


In [None]:

# Evaluate the trained model on the test loader, then save the weights and
# one sample image.
#
# eval() is required here: in training mode BatchNorm normalises with the
# statistics of each test batch and keeps updating its running statistics,
# which distorts the reported accuracy and alters the weights being saved.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Index of the largest logit per sample; `.data` is unnecessary
        # inside no_grad().
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
    # save the model
    torch.save(model.state_dict(), './resnet18.pth')
    # Save the first image of the last test batch. The tensor is still
    # standardised by the test transform's Normalize, and save_image clamps
    # to [0, 1], so undo the normalisation first to get real colours.
    mean = torch.tensor((0.4914, 0.4822, 0.4465)).view(1, 3, 1, 1)
    std = torch.tensor((0.2023, 0.1994, 0.2010)).view(1, 3, 1, 1)
    img = images[0].cpu()
    img = img.view(1, 3, 32, 32)
    img = img * std + mean
    # `epoch` is the last epoch index left over from the training cell.
    save_image(img, './mlp_img/image_{}.png'.format(epoch))

In [None]:
# # A program that trains the model on data

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = ResNet18(10).to(device)
# criterion = nn.MSELoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# for epoch in range(num_epochs):
#     for i, (img, label) in enumerate(train_loader):
#         img = img.to(device)
#         label = label.to(device)

#         output = model(img)
#         output = output.to(torch.float32)
#         label = label.to(torch.float32)
#         loss = criterion(output,label.view(-1,1))
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         if (i+1) % 100 == 0:
#             print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
#                   .format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))
    

   




In [None]:
# # Evaluate the model using the test dataset
# model.eval()
# test_loss = 0
# correct = 0
# with torch.no_grad():
#     for img, label in test_loader:
#         img = img.to(device)
#         label = label.to(device)
#         output = model(img)
#         test_loss += criterion(output, label.view(-1,1)).item()
#         pred = output.argmax(dim=1, keepdim=True)
#         correct += pred.eq(label.view_as(pred)).sum().item()
#         #show the image predicted and ground truth
#         plt.imshow(to_img(img[0]))
#         plt.tile(dataset_test.class_to_idx[label.cpu().data[0],  pred.cpu().data[0]])

# test_loss /= len(test_loader.dataset)
# print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
#         .format(test_loss, correct, len(test_loader.dataset),
#                 100. * correct / len(test_loader.dataset)))

# # Save the model checkpoints
# torch.save(model.state_dict(), './resnet.ckpt')
