In [1]:
import torch
import torch.nn as nn                 
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from model import KHopfield
# Define the Vision Transformer model

device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

from tqdm import tqdm

# auto reload
%load_ext autoreload
%autoreload 2


In [50]:
# Define the Vision Transformer model
class KVisionTransformer(nn.Module):
    def __init__(self, num_classes, embed_dim, num_heads, img_size, patch_size):
        super(KVisionTransformer, self).__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embedding = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.positional_embedding = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))
        self.hopfield = KHopfield(N=100, n=embed_dim * self.num_patches)
        self.fc = nn.Linear(embed_dim * self.num_patches, num_classes)
        self.num_heads = num_heads

    def forward(self, x):
        x1 = self.patch_embedding(x)  # (batch_size, embed_dim, num_patches_h, num_patches_w)
        x2 = x1.permute(0, 2, 3, 1)  # (batch_size, num_patches_h, num_patches_w, embed_dim)
        x3 = x2.reshape(x2.size(0), -1, x2.size(-1))  # (batch_size, num_patches, embed_dim)
        
        x4 = x3 + self.positional_embedding  # Add positional embedding
        # combine second and third dimension
        x5 = x4.flatten(1, 2)
        x6 = self.hopfield(x5, self.num_heads)
        x7 = x6.mean(dim=2)  # Global average pooling
        x8 = self.fc(x7)
        return x8
    
    def to(self, device):
        super(KVisionTransformer, self).to(device)
        self.hopfield = self.hopfield.to(device)
        return self
    
# Define the Vision Transformer model
class VisionTransformer(nn.Module):
    def __init__(self, num_classes, embed_dim, num_heads, img_size, patch_size):
        super(VisionTransformer, self).__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embedding = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.positional_embedding = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))
        self.hopfield = KHopfield(N=100, n=embed_dim * self.num_patches)
        self.fc = nn.Linear(embed_dim * self.num_patches, num_classes)
        self.num_heads = num_heads

    def forward(self, x):
        x1 = self.patch_embedding(x)  # (batch_size, embed_dim, num_patches_h, num_patches_w)
        x2 = x1.permute(0, 2, 3, 1)  # (batch_size, num_patches_h, num_patches_w, embed_dim)
        x3 = x2.reshape(x2.size(0), -1, x2.size(-1))  # (batch_size, num_patches, embed_dim)
        
        x4 = x3 + self.positional_embedding  # Add positional embedding
        # combine second and third dimension
        x5 = x4.flatten(1, 2)
        x6 = self.hopfield(x5, self.num_heads)
        x7 = x6.mean(dim=2)  # Global average pooling
        x8 = self.fc(x7)
        return x8
    
    def to(self, device):
        super(VisionTransformer, self).to(device)
        self.hopfield = self.hopfield.to(device)
        return self

In [3]:
# Define the Vision Transformer model
class VisionTransformer(nn.Module):
    def __init__(self, num_classes, embed_dim, num_heads, num_layers, img_size, patch_size):
        super(VisionTransformer, self).__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embedding = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
        self.transformer = nn.Transformer(
            d_model=embed_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            dim_feedforward=2048,
            dropout=0.1,
        )
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)  # (batch_size, embed_dim, num_patches_h, num_patches_w)
        x = x.permute(0, 2, 3, 1)  # (batch_size, num_patches_h, num_patches_w, embed_dim)
        x = x.reshape(x.size(0), -1, x.size(-1))  # (batch_size, num_patches, embed_dim)
        x = torch.cat([self.positional_embedding, x], dim=1)
        x = self.transformer(x)
        x = x.mean(dim=1)  # Global average pooling
        x = self.fc(x)
        return x


In [51]:
# Hyperparameters
batch_size = 64
num_epochs = 10
learning_rate = 1e-4
num_classes = 10
img_size = 32  # Assuming CIFAR-10 image size
patch_size = 16  # Adjust this based on your preference

# Data preprocessing
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize the model and optimizer
model = VisionTransformer(num_classes, embed_dim=256, num_heads=4, img_size=img_size, patch_size=patch_size)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

model = model.to(device)
print(device)
print(model.hopfield.memories.device)


Files already downloaded and verified
cuda:3
cuda:3


In [53]:
# Training loop
for epoch in range(num_epochs):
    # show loss in tqdm
    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (images, labels) in pbar:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        pbar.set_description(f'Epoch [{epoch + 1}/{num_epochs}] Loss: {loss.item():.4f}')
    print(f'Epoch [{epoch + 1}/{num_epochs}] Loss: {loss.item():.4f}')

print('Training finished!')

Epoch [1/10] Loss: 1.7319: 100%|██████████| 782/782 [00:31<00:00, 25.17it/s]


Epoch [1/10] Loss: 1.7319


Epoch [2/10] Loss: 1.9022: 100%|██████████| 782/782 [00:33<00:00, 23.37it/s]


Epoch [2/10] Loss: 1.9022


Epoch [3/10] Loss: 1.4298: 100%|██████████| 782/782 [00:33<00:00, 23.34it/s]


Epoch [3/10] Loss: 1.4298


Epoch [4/10] Loss: 1.2505: 100%|██████████| 782/782 [00:32<00:00, 23.87it/s]


Epoch [4/10] Loss: 1.2505


Epoch [5/10] Loss: 1.4505:  42%|████▏     | 329/782 [00:14<00:19, 23.16it/s]


KeyboardInterrupt: 

In [49]:
X = torch.nn.Parameter(torch.randn(59, 16))

hopfield = KHopfield(N=100, n=16 )

# optimize the hopfield network
optimizer = optim.Adam(hopfield.parameters(), lr=learning_rate)

for i in range(10):
    print(i)
    Y = hopfield(X, 4)
    loss = torch.norm(Y)
    loss.backward()
    print(loss)
    optimizer.step()

0
cpu cpu
tensor(0.9093, grad_fn=<CopyBackwards>)
1
cpu cpu
tensor(0.9053, grad_fn=<CopyBackwards>)
2
cpu cpu
tensor(0.9015, grad_fn=<CopyBackwards>)
3
cpu cpu
tensor(0.8977, grad_fn=<CopyBackwards>)
4
cpu cpu
tensor(0.8939, grad_fn=<CopyBackwards>)
5
cpu cpu
tensor(0.8901, grad_fn=<CopyBackwards>)
6
cpu cpu
tensor(0.8863, grad_fn=<CopyBackwards>)
7
cpu cpu
tensor(0.8824, grad_fn=<CopyBackwards>)
8
cpu cpu
tensor(0.8786, grad_fn=<CopyBackwards>)
9
cpu cpu
tensor(0.8747, grad_fn=<CopyBackwards>)


tensor([[ 0.0077,  0.0003,  0.0085,  ..., -0.0059, -0.0042,  0.0050],
        [ 0.0098,  0.0172,  0.0045,  ..., -0.0088, -0.0082, -0.0084],
        [ 0.0055, -0.0034,  0.0086,  ..., -0.0065,  0.0126,  0.0076],
        ...,
        [-0.0039, -0.0002, -0.0007,  ..., -0.0057, -0.0072, -0.0051],
        [ 0.0059, -0.0109,  0.0126,  ..., -0.0065,  0.0156, -0.0074],
        [-0.0067, -0.0028, -0.0106,  ..., -0.0015,  0.0015, -0.0081]],
       grad_fn=<MulBackward0>)
