In [2]:
import torch
import torch.utils
import torch.utils.data
from torchvision import datasets, transforms
import torchvision
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader



# Mini-batch size shared by the image loaders created in this cell.
BATCH_SIZE = 32

# Preprocessing pipeline: force a single grayscale channel, shrink to 28x28,
# then convert the PIL image to a float tensor.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ]
)

# Two Omniglot splits selected by the `background` flag (True first, then False),
# downloaded into ./data when not already present.
train_dataset, test_dataset = (
    datasets.Omniglot(root='./data', background=is_background, transform=transform, download=True)
    for is_background in (True, False)
)

# Plain torch loaders over (image, label) pairs; both splits are shuffled.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=BATCH_SIZE, shuffle=True)
    for split in (train_dataset, test_dataset)
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')


Using device: cpu


In [None]:
import matplotlib.pyplot as plt

# Pull one batch of (images, labels) from the training loader for a visual check.
Image, Label = next(iter(train_loader))

# Tile the batch into a single image, 8 pictures per row, rescaled for display.
grid_img = torchvision.utils.make_grid(Image, normalize=True, nrow=8)

# Move the leading (channel) axis to the end, which is the layout imshow expects.
np_img = grid_img.numpy().transpose(1, 2, 0)

plt.figure(figsize=(10, 10))
plt.imshow(np_img, cmap='gray')
plt.title('Batch of Images from train_loader')
plt.show()

In [None]:
import torch.nn as nn
from torch_geometric.nn.conv import GATConv
from torch_geometric.nn import global_mean_pool

class GATfewshot(nn.Module):
    """Two-layer graph attention encoder that pools node embeddings per graph.

    Args:
        in_channels: number of input features per node.
        hidden_channels: output width of the first attention layer.
        out_channels: width of the final per-graph embedding.
        num_heads: attention heads in the first layer (averaged, not concatenated).
        dropout: dropout probability applied to the attention coefficients.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=4, dropout=0.3):
        super(GATfewshot, self).__init__()
        # Keyword arguments are required here: GATConv's positional order is
        # (in, out, heads, concat, negative_slope, dropout, ...), so passing
        # `dropout` as the fifth positional value set the LeakyReLU slope
        # instead and left attention dropout disabled.
        self.gat1 = GATConv(in_channels, hidden_channels, heads=num_heads, concat=False, dropout=dropout)
        self.gat2 = GATConv(hidden_channels, out_channels, heads=1, concat=False, dropout=dropout)

    def forward(self, x, edge_index, batch_size):
        """Encode a batch of graphs.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity passed to both attention layers.
            batch_size: despite the name, this is forwarded to
                ``global_mean_pool`` as the per-node graph assignment vector
                (the training loop passes ``data.batch``), not an integer count.

        Returns:
            One pooled embedding per graph.
        """
        out = self.gat1(x, edge_index)
        out = torch.relu(out)
        out = self.gat2(out, edge_index)
        # Average node embeddings within each graph.
        out = global_mean_pool(out, batch_size)
        return out
    


class Generator(nn.Module):
    """MLP generator mapping a latent vector to a 1x28x28 image with pixels in (0, 1).

    Args:
        z_dim: size of the latent input vector.
        hidden_dim: width of the hidden layer.
        out_dim: number of output values per sample; must equal 28 * 28
            because ``forward`` reshapes the output to (N, 1, 28, 28).
    """

    def __init__(self, z_dim, hidden_dim, out_dim):
        super(Generator, self).__init__()
        self.l1 = nn.Linear(z_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, z):
        """Return generated images of shape (N, 1, 28, 28) for latent batch ``z``."""
        x = torch.relu(self.l1(z))
        # No ReLU on the output layer: sigmoid(relu(v)) is always >= 0.5, which
        # would make it impossible to generate dark pixels.
        x = torch.sigmoid(self.l2(x))
        return x.view(x.size(0), 1, 28, 28)
    
class Discriminator(nn.Module):
    """Small CNN that scores a 1x28x28 image with a probability in (0, 1).

    Two stride-2 convolutions halve the spatial size twice (28 -> 14 -> 7);
    the flattened 128x7x7 feature map is reduced to a single sigmoid output.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(in_features=128 * 7 * 7, out_features=1)

    def forward(self, x):
        """Return a (N, 1) tensor of scores for a batch of images ``x``."""
        features = self.conv1(x).relu()
        features = self.conv2(features).relu()
        # Flatten everything except the batch dimension before the linear head.
        flat = features.view(features.size(0), -1)
        return self.fc1(flat).sigmoid()




In [None]:
# NOTE(review): this cell redefines `Generator`, shadowing the earlier
# definition in the model cell; this later class is the one later cells use.
class Generator(nn.Module):
    """MLP generator mapping a latent vector to a 1x28x28 image with pixels in (0, 1).

    Args:
        z_dim: size of the latent input vector.
        hidden_dim: width of the hidden layer.
        out_dim: number of output values per sample; must equal 28 * 28
            because ``forward`` reshapes the output to (N, 1, 28, 28).
    """

    def __init__(self, z_dim, hidden_dim, out_dim):
        super(Generator, self).__init__()
        self.l1 = nn.Linear(z_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, z):
        """Return generated images of shape (N, 1, 28, 28) for latent batch ``z``."""
        x = torch.relu(self.l1(z))
        # No ReLU on the output layer: sigmoid(relu(v)) is always >= 0.5, which
        # would make it impossible to generate dark pixels.
        x = torch.sigmoid(self.l2(x))
        return x.view(x.size(0), 1, 28, 28)
    



In [None]:
# GAT encoder sizes: one feature per node in, one embedding per graph out.
gat_in_channels = 1
gat_hidden_channels = 64
# The training loop reshapes the GAT output to (batch, z_dim), so this must match z_dim.
gat_out_channels = 100

# Generator sizes: latent vector -> hidden layer -> flattened 28x28 image.
z_dim = 100
g_hidden_dim = 256
g_out_dim = 28 * 28

# Build each network and place it on the selected device (Module.to returns the module itself).
Disc = Discriminator().to(device)
Genr = Generator(z_dim, g_hidden_dim, g_out_dim).to(device)
GAT = GATfewshot(gat_in_channels, gat_hidden_channels, gat_out_channels).to(device)

# Last expression of the cell: display the GAT module summary.
GAT


In [None]:
from itertools import product

def create_grid_edge_index(size):
    """Build the edge list of a ``size`` x ``size`` 4-neighbour pixel grid.

    Nodes are numbered row-major (node ``i`` sits at row ``i // size``, column
    ``i % size``). Each node is linked to the node directly below it and the
    node directly to its right, when those exist. Edges are emitted in that
    one direction only; callers that need an undirected graph must add the
    reversed pairs themselves.

    Args:
        size: number of pixels along each side of the grid.

    Returns:
        A contiguous ``torch.long`` tensor of shape ``(2, num_edges)``. Grids
        with no edges (``size <= 1``) yield shape ``(2, 0)``.
    """
    edges = []

    for i, (x, y) in enumerate(product(range(size), repeat=2)):
        if x < size - 1:
            edges.append((i, i + size))  # Connect to pixel below
        if y < size - 1:
            edges.append((i, i + 1))  # Connect to pixel to the right

    # Reshape explicitly so an empty edge list still becomes (2, 0) rather than
    # a 1-D tensor of shape (0,).
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(len(edges), 2).t().contiguous()
    return edge_index


# Grid side length; matches the 28x28 resize applied to the images.
size = 28
# create_grid_edge_index only emits down/right edges. Message passing follows
# edge direction, so append the reversed pairs to make the pixel grid undirected.
_directed_edges = create_grid_edge_index(size)
edge_index = torch.cat([_directed_edges, _directed_edges.flip(0)], dim=1)

In [None]:
class GraphDataset(torch.utils.data.Dataset):
    """Dataset wrapper that serves each stored image as a pixel-grid graph."""

    def __init__(self, images):
        # Indexable collection of image tensors; kept as-is.
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Flatten to one row per pixel: each pixel is a node with 1 feature (its intensity).
        node_features = self.images[idx].view(-1, 1)
        # Every sample shares the module-level grid connectivity `edge_index`.
        return Data(x=node_features, edge_index=edge_index)
    

# Convert Omniglot dataset to graph dataset (labels are dropped; only images are kept)
graph_dataset = GraphDataset([image for image, _ in train_dataset])
# Reuse the notebook-wide BATCH_SIZE constant instead of repeating the literal 32.
graph_loader = DataLoader(graph_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Loss and optimizers
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network (discriminator, generator, GAT), same learning rate.
optimizer_d, optimizer_g, optimizer_gat = (
    torch.optim.Adam(network.parameters(), lr=0.002) for network in (Disc, Genr, GAT)
)

In [None]:
# Sanity check: display the shape of the grid `edge_index` tensor.
edge_index.shape

In [None]:
# NOTE(review): leftover exploratory expression — it only displays the
# nn.Upsample class object; nothing else in the visible notebook uses it.
nn.Upsample

In [None]:
# Training loop
# Each batch runs three updates in order: discriminator (real vs. generated),
# generator (from random noise), then the GAT encoder (its pooled embedding is
# fed to the generator in place of noise and scored by the discriminator).
num_epochs = 10

for epoch in range(num_epochs):
    for data in graph_loader:
        # Move the whole batch to the models' device once; previously only the
        # GAT step did so, and the other tensors stayed on the CPU, which
        # fails as soon as the models live on CUDA.
        data = data.to(device)

        # data.x stacks the 28*28 pixel nodes of every graph in the batch.
        batch_size = data.x.size(0) // (28 * 28)
        images = data.x.view(batch_size, 28, 28, 1).permute(0, 3, 1, 2)

        labels = torch.ones(batch_size, 1, device=device)  # Real images
        fake_labels = torch.zeros(batch_size, 1, device=device)  # Fake images

        # Train Discriminator
        optimizer_d.zero_grad()
        real_outputs = Disc(images)
        d_loss_real = criterion(real_outputs, labels)
        d_loss_real.backward()

        z = torch.randn(batch_size, z_dim, device=device)
        fake_images = Genr(z)
        # detach(): the discriminator update must not backpropagate into the generator.
        fake_outputs = Disc(fake_images.detach())
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss_fake.backward()
        optimizer_d.step()

        # Train Generator: try to make the discriminator label fakes as real.
        optimizer_g.zero_grad()
        fake_outputs = Disc(fake_images)
        g_loss = criterion(fake_outputs, labels)
        g_loss.backward()
        optimizer_g.step()

        # Train GAT
        # Only optimizer_gat steps here; gradients that also reach Genr/Disc
        # are cleared by their zero_grad() calls before their next update.
        optimizer_gat.zero_grad()
        g_features = GAT(data.x, data.edge_index, data.batch)
        g_features = g_features.view(batch_size, z_dim)  # Reshape to match the Generator input
        fake_images = Genr(g_features)
        gat_outputs = Disc(fake_images)
        gat_loss = criterion(gat_outputs, labels)
        gat_loss.backward()
        optimizer_gat.step()

    # Reports the losses of the last batch of the epoch only.
    print(f'Epoch [{epoch+1}/{num_epochs}] | D Loss: {d_loss_real.item() + d_loss_fake.item()} | G Loss: {g_loss.item()} | GAT Loss: {gat_loss.item()}')

# Evaluation
# Sample 16 latent vectors and show the generated images in a 4x4 grid.
with torch.no_grad():
    # Create the noise on the generator's device to avoid a CPU/GPU mismatch.
    z = torch.randn(16, z_dim, device=device)
    fake_images = Genr(z)
    # matplotlib needs host memory, so bring the images back to the CPU first.
    grid_img = torchvision.utils.make_grid(fake_images.cpu(), nrow=4, normalize=True)
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.title('Generated Images')
    plt.show()