In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import imageio

In [2]:
# XOR dataset
X = torch.tensor([[0,0],[0,1],[1,0],[1,1]],dtype = torch.float32)
y = torch.tensor([0,1,1,0],dtype = torch.long)

In [3]:
# MLP with one hidden layer
class XORNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(2,4)         # hidden layer: takes 2 inputs, outputs 4 values
    self.fc2 = nn.Linear(4,2)         # output layer  takes 4 inputs, outputs 2 values

  def forward(self,x):
    x = torch.tanh(self.fc1(x))       # nonlinear hidden layer because of tanh
    x = self.fc2(x)                   # logits since it performs only linear transformation on output of previous layer
    return x

In [4]:
model = XORNN()
print(model)

XORNN(
  (fc1): Linear(in_features=2, out_features=4, bias=True)
  (fc2): Linear(in_features=4, out_features=2, bias=True)
)


In [5]:
optimizer = torch.optim.SGD(model.parameters(),lr = 0.1)    # torch.optim.SGD is a class that implements SGD
                                                            # model.parameters() pass all parameters of model to class so that it knows what to update
                                                            # learning rate = 0.1
criterion = nn.CrossEntropyLoss()                           # object of CrossEntropy loss uses softmax to convert logits to prob and then calculates negative log-likelyhood


In [7]:
frames = []

def plot_decision_boundary(epoch):
  xx,yy = np.meshgrid(np.linspace(-0.5,1.5,200),np.linspace(-0.5,1.5,200))
  grid = torch.tensor(np.c_[xx.ravel(),yy.ravel()],dtype = torch.float32)       # flatten's the xx and yy by ravel() .c_ creates pairs
  with torch.no_grad():
    logits = model(grid)
    Z = torch.argmax(F.softmax(logits,dim=1),axis=1).reshape(xx.shape)
    print(Z)

    plt.figure(figsize=(5,5))
    plt.contourf(xx,yy,Z,alpha = 0.6,cmap = "coolwarm")
    plt.scatter(X[:,0],X[:,1],c=y,cmap = "coolwarm",edgecolors = "k",s=100)
    plt.title(f"Decisio Boundary at Epoch {epoch}")
    plt.savefig(f"frame_{epoch}.png")
    plt.close()
    frames.append(imageio.imread(f"frame_{epoch}.png"))


In [8]:
# Training loop
epochs = 2000
for epoch in range(epochs):
  logits = model(X)
  loss = criterion(logits,y)          # CrossEntropy loss has a callable interface that allows objects to be used as interface

  optimizer.zero_grad()               # sets gradient to zero before computing new grad
  loss.backward()                     # computes gradients of loss
  optimizer.step()                    # updates the model parameters

  if epoch % 200 == 0:
    print(f"Epoch {epoch}/{epochs}, Loss = {loss.item():.4f}")
    plot_decision_boundary(epoch)


Epoch 0/2000, Loss = 0.7385
tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])


  frames.append(imageio.imread(f"frame_{epoch}.png"))


Epoch 200/2000, Loss = 0.6433
tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
Epoch 400/2000, Loss = 0.3670
tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
Epoch 600/2000, Loss = 0.0779
tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
Epoch 800/2000, Loss = 0.0310
tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
Epoch 10

In [None]:
imageio.mimsave("xor_solution.gif",frames,fps=3)
print("✅ Training complete — GIF saved as xor_solution.gif")

✅ Training complete — GIF saved as xor_solution.gif
