Codice con il metodo di Runge-Kutta del quarto ordine, implementato a mano. Non gode del vantaggio in termini di memoria di odeint, ma in compenso è molto più veloce.

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import math

In [2]:
# Pick the compute device once: the GPU when CUDA is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
class StochasticMNIST(Dataset):
    """Map-style dataset pairing each entry of ``data`` with the matching entry of ``targets``.

    Args:
        data: Indexable, sized container of samples.
        targets: Indexable container of labels, addressed with the same index as ``data``.
    """

    def __init__(self, data, targets):
        self.data, self.targets = data, targets

    def __len__(self):
        """Return the number of samples (the length of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, target)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        label = self.targets[idx]
        return sample, label

# Load the pre-encoded images.
# NOTE(security): weights_only=False unpickles arbitrary Python objects, which can
# execute code. Only load these files if they come from a trusted source.
train_data, train_targets = torch.load("qsw_train_data_encoded64.pt", weights_only=False)
test_data, test_targets = torch.load("qsw_test_data_encoded64.pt", weights_only=False)

# Build the datasets.
qsw_train_data = StochasticMNIST(train_data, train_targets)
qsw_test_data = StochasticMNIST(test_data, test_targets)

batch_size = 64
# Number of DataLoader worker processes; lower this on machines with fewer CPU cores.
num_workers = 32

qsw_train_dataloader = DataLoader(qsw_train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# The evaluation loader is not shuffled: the test loss is an average of per-batch means,
# so a fixed batch composition keeps the reported value reproducible between evaluations.
qsw_test_dataloader = DataLoader(qsw_test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Sanity check on the first training batch: shapes, dtypes and the trailing entries
# of the first sample.
for X, y in qsw_train_dataloader:
    print(f'Shape dataset: {X.shape}, {y.shape}')
    print(f'Data type: {X.dtype}, {y.dtype}')
    print(f'Padding: {X[0,-10:]}')
    break

Shape dataset: torch.Size([64, 74]), torch.Size([64])
Data type: torch.complex64, torch.int64
Padding: tensor([0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j])


In [10]:
class SQWalker(nn.Module):
    r"""Stochastic quantum walk on an ``n x n`` lattice with ``out_features`` sinks.

    The module evolves a batch of density matrices of size ``(N + s) x (N + s)``
    (``N = in_features`` lattice nodes followed by ``s = out_features`` sink nodes)
    with a fixed-step fourth-order Runge-Kutta scheme and returns the populations
    accumulated in the sinks.

    Trainable parameters (stored as ``complex64``; only their real part is used):
        * ``weights``   - one value per allowed lattice move (see ``create_mask``).
        * ``sinkrates`` - one value per sink; sink ``k`` is attached to lattice node ``k``.

    Args:
        in_features: Number of lattice nodes ``N``; must be a perfect square.
        out_features: Number of sinks ``s``.
        dt: Integration time step.
        steps: Number of time points; ``steps - 1`` Runge-Kutta updates are performed.
        l: Half-width of the square neighbourhood that defines the allowed moves.
        noise: Mixing parameter ``p``: the Hamiltonian is scaled by ``1 - p`` and the
            incoherent hopping/dephasing terms by ``p``.
        device: Device on which the parameters and auxiliary tensors are created.

    Raises:
        ValueError: If ``in_features`` is not a perfect square.
    """

    def __init__(self, in_features, out_features, dt=1e-3, steps=1000, l = 1, noise=1., device=None):
        # nn.Module must be initialised before attributes are attached to it.
        super().__init__()
        self.factory_kwargs = {"device": device, "dtype": torch.complex64}

        # Model hyper-parameters.
        self.n = math.isqrt(in_features)
        if self.n * self.n != in_features:
            raise ValueError(
                f"in_features must be a perfect square (n x n lattice), got {in_features}"
            )
        self.N = in_features
        self.s = out_features
        self.p = noise
        self.steps = steps
        self.dt = dt

        # Trainable sink rates. For simplicity the first `out_features` lattice nodes
        # are the ones connected to the sinks (sink k <- node k).
        self.sinkrates = nn.parameter.Parameter(
            torch.empty((1, self.s), **self.factory_kwargs)
        )
        mask_sinks = torch.zeros((self.s, self.N), **self.factory_kwargs)
        mask_sinks.fill_diagonal_(1)
        # Auxiliary tensors are registered as non-persistent buffers: they follow the
        # module through .to()/.cuda() but are kept out of the state_dict, so existing
        # checkpoints (which only contain the two parameters) still load.
        self.register_buffer("mask_sinks", mask_sinks, persistent=False)

        # Constant part of the matrix `a` used by `lindblad`: p/2 on the lattice block.
        B = torch.zeros(self.N + self.s, self.N + self.s, **self.factory_kwargs)
        B[:self.N, :self.N] = 0.5 * self.p * torch.eye(self.N, **self.factory_kwargs)
        self.register_buffer("B", B, persistent=False)

        # Boolean mask of the allowed lattice moves; one trainable weight per True entry.
        self.register_buffer("mask", self.create_mask(l, device), persistent=False)
        self.expressivity = int(torch.count_nonzero(self.mask))
        self.weights = nn.parameter.Parameter(
            torch.empty((1, self.expressivity), **self.factory_kwargs)
        )
        self.reset_parameters()

    def create_mask(self, l, device):
        """Return the ``(N, N)`` boolean mask of allowed moves on the lattice.

        Row ``i*n + j`` marks the nodes lying in the ``(2l+1) x (2l+1)`` square
        centred on lattice site ``(i, j)``, excluding the site itself and anything
        falling outside the lattice.
        """
        moves = torch.ones(2*l+1, 2*l+1, device=device)
        moves[l, l] = 0  # no self-loop
        # Embed the stencil at the centre of a (2n-1) x (2n-1) canvas so that an
        # n x n window shifted by (i, j) gives the neighbourhood of site (i, j).
        pad = self.n - l - 1
        M = nn.functional.pad(moves, (pad, pad, pad, pad), 'constant', 0)

        A = torch.zeros((self.N, self.N), device=device)
        for i in range(self.n):
            for j in range(self.n):
                window = M[self.n-1-i:2*self.n-1-i, self.n-1-j:2*self.n-1-j]
                A[i*self.n+j] = window.flatten()
        return A == 1

    def reset_parameters(self):
        """(Re-)initialise both parameter tensors with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.sinkrates, a=math.sqrt(5))

    def lindblad(self, rho, H, laplacian, a):
        """Return ``d rho / dt`` for the density matrix (or batch of matrices) ``rho``.

        The right-hand side is the sum of the commutator ``-i[H, rho]``, a diagonal
        gain term that applies ``laplacian`` to the populations of ``rho``, and the
        loss term ``-(a rho + rho a)``.
        """
        drho = -1j * (torch.matmul(H, rho) - torch.matmul(rho, H))
        # Populations are read from the last two dimensions, so both a single matrix
        # and a batch of matrices are accepted.
        populations = torch.diagonal(rho, dim1=-2, dim2=-1)
        drho = drho + torch.diag_embed(torch.matmul(populations, laplacian.T))
        drho = drho - (torch.matmul(a, rho) + torch.matmul(rho, a))
        return drho

    def forward(self, input, history = False):
        """Integrate the walk starting from the diagonal state ``diag(input)``.

        Args:
            input: Tensor of shape ``(batch, N + s)`` holding the initial diagonal of
                the density matrix for every sample.
            history: When True, return every intermediate density matrix instead of
                the final sink populations.

        Returns:
            If ``history`` is False, a real tensor of shape ``(batch, s)`` with the
            final sink populations. If ``history`` is True, the ``steps`` density
            matrices concatenated along the first dimension, i.e. a tensor of shape
            ``(steps * batch, N + s, N + s)``.
        """
        # Allocate on the device the parameters currently live on, so the module keeps
        # working after .to(device).
        kwargs = {"device": self.weights.device, "dtype": torch.complex64}
        dim = self.N + self.s

        # Lattice couplings: keep only the real part of the weights, symmetrise and
        # square, so that every entry of A is non-negative.
        A = torch.zeros((self.N, self.N), **kwargs)
        A[self.mask] = self.weights.reshape(-1)
        A = A - 1j * A.imag
        A = 0.5*(A + A.T)**2

        # Sink rates: real part of the parameters, squared to make them non-negative.
        gamma = torch.zeros((self.s, self.N), **kwargs)
        gamma[self.mask_sinks == 1] = self.sinkrates.reshape(-1)
        gamma = gamma - 1j * gamma.imag
        gamma = (gamma)**2

        # Column sums of A, used to normalise the incoherent hopping rates.
        degA = torch.sum(A, dim=0)

        # Gain matrix: normalised hopping on the lattice block, node -> sink transfer below it.
        laplacian = torch.zeros(dim, dim, **kwargs)
        laplacian[:self.N, :self.N] = self.p * A / degA
        laplacian[self.N:, :self.N] = gamma

        # Loss matrix `a`: half the total outgoing sink rate of each node plus the
        # constant p/2 term. R is also kept on the instance, as in earlier versions.
        R = torch.zeros(dim, dim, **kwargs)
        R[self.N:, :self.N] = gamma
        self.R = R
        a = 0.5 * torch.diag(torch.sum(R, dim=0)) + self.B

        # Hamiltonian: (1 - p) * A on the lattice block, identity on the sink block.
        H = torch.eye(dim, **kwargs)
        H[:self.N, :self.N] = (1 - self.p) * A

        # Initial density matrix, then classic fixed-step fourth-order Runge-Kutta.
        rho = torch.diag_embed(input)

        if history:
            movie = [rho]

        for _ in range(self.steps-1):
            k1 = self.dt * self.lindblad(rho, H, laplacian, a)
            k2 = self.dt * self.lindblad(rho + k1/2, H, laplacian, a)
            k3 = self.dt * self.lindblad(rho + k2/2, H, laplacian, a)
            k4 = self.dt * self.lindblad(rho + k3, H, laplacian, a)

            rho = rho + (k1 + 2*k2 + 2*k3 + k4) / 6
            if history:
                movie.append(rho)
        if history:
            return torch.cat(movie)
        return rho.diagonal(dim1=-2, dim2=-1)[:, -self.s:].real

    def extra_repr(self) -> str:
        """Summarise the layer sizes and the percentage of allowed lattice couplings."""
        return "in_features: {a}, out_features: {b}\n Expressivity: {c:.2f}%".format(a = self.N, b = self.s, c = self.expressivity*100/self.N**2)

In [11]:
# Build the walker: 64 lattice nodes (8 x 8), 10 sinks, unit time step and 50 time
# points, then move its parameters to the selected device.
model = SQWalker(64, 10, dt = 1., steps = 50, l = 1, noise = 1., device=device).to(device)
print(model)

SQWalker(
  in_features: 64, out_features: 10
   Expressivity: 10.25%
)


In [16]:
# Loss and optimiser used by the training loop.
# NOTE(review): CrossEntropyLoss applies a log-softmax to its input, while the model
# returns sink populations rather than unbounded logits - confirm this is intended.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [13]:
def train(dataloader, model, loss_fn, optimizer, voice = True):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing a ``dataset`` attribute.
        model: Module to optimise; batches are moved to the device of its parameters.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimiser bound to ``model``'s parameters.
        voice: When True, print the current loss at the first batch and then roughly
            every 10000 processed samples.
    """
    size = len(dataloader.dataset)
    # Use the device the model actually lives on instead of a notebook-level global.
    model_device = next(model.parameters()).device
    model.train()
    current = 0
    i = 0
    for X, y in dataloader:
        X, y = X.to(model_device), y.to(model_device)

        # Clear gradients *before* the backward pass, so that gradients left over from
        # an interrupted or earlier run cannot leak into this update.
        optimizer.zero_grad()

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()

        current += len(X)
        if voice and current >= i*10000:
            i += 1
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

In [14]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return 100*correct, test_loss

In [17]:
# Main optimisation loop: one training pass followed by one evaluation pass per epoch.
# The (accuracy, loss) pair returned by test() is only printed, not stored.
epochs = 50
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(qsw_train_dataloader, model, loss_fn, optimizer)
    test(qsw_test_dataloader, model, loss_fn)
print('Done!')

loss: 1.874731  [   64/60000]
loss: 1.926513  [10048/60000]
loss: 1.838157  [20032/60000]
loss: 1.860362  [30016/60000]
loss: 1.974287  [40000/60000]
loss: 1.866854  [50048/60000]
loss: 1.940316  [60000/60000]
Test Error: 
 Accuracy: 38.1%, Avg loss: 1.942555 

loss: 1.886816  [   64/60000]
loss: 2.041170  [10048/60000]
loss: 1.945514  [20032/60000]
loss: 1.921605  [30016/60000]
loss: 1.962546  [40000/60000]
loss: 1.998242  [50048/60000]
loss: 2.040855  [60000/60000]
Test Error: 
 Accuracy: 35.7%, Avg loss: 1.939114 

loss: 1.965268  [   64/60000]
loss: 1.786875  [10048/60000]
loss: 1.918204  [20032/60000]
loss: 2.066149  [30016/60000]
loss: 1.894915  [40000/60000]
loss: 1.912167  [50048/60000]
loss: 1.912428  [60000/60000]
Test Error: 
 Accuracy: 33.2%, Avg loss: 1.938927 

loss: 1.998522  [   64/60000]
loss: 2.005302  [10048/60000]
loss: 2.023532  [20032/60000]
loss: 1.744122  [30016/60000]
loss: 2.018036  [40000/60000]
loss: 1.970902  [50048/60000]
loss: 2.037971  [60000/60000]
Test

KeyboardInterrupt: 

In [10]:
# Persist only the trainable parameters (state_dict), not the whole module object.
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

Saved PyTorch Model State to model.pth


In [None]:
# Rebuild the model and restore the trained parameters.
# NOTE(review): dt/steps here (0.001 / 1000) differ from the values used in the
# training cell (1. / 50), so the integration horizon is not the one the parameters
# were trained with - confirm this is intended.
model = SQWalker(64, 10, dt = 0.001, steps = 1000, l = 1, noise = 1., device=device).to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))

In [None]:
# Class labels: the ten digits, as strings.
classes = [str(digit) for digit in range(10)]

model.eval()
n_datas = len(test_data)

# Draw one random test sample and integrate it while recording the whole trajectory.
c = torch.randint(n_datas,(1,)).item()
x, y = qsw_test_data[c]
with torch.no_grad():
    x = x.to(device)
    pred = model(x.unsqueeze(0), history=True).squeeze(1)
    # The last recorded density matrix; its final 10 diagonal entries are the sinks.
    final_rho = pred[-1]
    sinks = final_rho.diagonal(dim1=-2, dim2=-1)[-10:].real
    predicted = classes[sinks.argmax()]
    actual = classes[y]

# Fraction of the initial population that has reached the sinks.
print(f'Convergenza: {(100*torch.sum(sinks)/torch.sum(x.real)):.2f}%')
print(f'Stato finale: {sinks}')