In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
# import matplotlib.pyplot as plt
from tqdm import tqdm
from htorch import layers
import time

In [2]:
device = torch.device('cuda:0')

data = torch.tensor(pd.read_csv("../data/mnist/train.csv", header=None).values)
x, y = (data[:, 1:]/255).float().to(device), torch.nn.functional.one_hot(data[:, 0].long(), 10).to(device)

test = torch.tensor(pd.read_csv("../data/mnist/test.csv", header=None).values)
x_test, y_test = (test[:, 1:]/255).float().to(device), torch.nn.functional.one_hot(test[:, 0].long(), 10).to(device)

In [3]:
class Real(nn.Module): 
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc4 = nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc4(x)
        return x
real_model = Real()
real_model.to(device)
optimiser = torch.optim.Adam(real_model.parameters(), lr=1.2e-3)

In [4]:
class Quat(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = layers.QLinear(196, 75)
        self.fc2 = layers.QLinear(75, 25)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        # x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
quat_model = Quat()
quat_model.to(device)
optimiser = torch.optim.Adam(quat_model.parameters(), lr=1.2e-3)

In [5]:
for batch_size in [2**(i) for i in range(5, 14)][::-1]:
    for model in ["real", "quat"]:
        if model == "quat":
            model = Quat()
            model.to(device)
            optimiser = torch.optim.Adam(model.parameters(), lr=1.2e-3)
        elif model == "real":
            model = Real()
            model.to(device)
            optimiser = torch.optim.Adam(model.parameters(), lr=1.2e-3)
        
        t_f = 0
        t_b = 0
        
        for epoch in tqdm(range(1000)):
            for i in range(0, len(x), batch_size):
                batch_x, batch_y = x[i:i+batch_size], y[i:i+batch_size].float()
                optimiser.zero_grad()
                
                t0 = time.time()
                output = model(batch_x)
                t_f += time.time() - t0
                
                t0 = time.time()
                loss = F.mse_loss(output, batch_y)
                loss.backward()
                t_b += time.time() - t0
                
                optimiser.step()
            # losses.append(loss.item())
        if isinstance(model, Real):
            message = f"""<tr>
            <td rowspan="2"><b>{batch_size}</b></td>
            <td><b>Real</b></td>
            <td>{t_f:.3f}ms</td>
            <td>{t_b:.3f}ms</td>
            </tr>"""
        else:
            message = f"""<tr>
            <td><b>Quat</b></td>
            <td>{t_f:.3f}ms</td>
            <td>{t_b:.3f}ms</td>
            </tr>"""
        print(message)

100%|██████████| 1000/1000 [00:07<00:00, 125.57it/s]


<tr>
            <td rowspan="2"><b>8192</b></td>
            <td><b>Real</b></td>
            <td>2.833ms</td>
            <td>1.224ms</td>
            </tr>


100%|██████████| 1000/1000 [00:09<00:00, 110.45it/s]


<tr>
            <td><b>Quart</b></td>
            <td>1.741ms</td>
            <td>4.409ms</td>
            </tr>


100%|██████████| 1000/1000 [00:09<00:00, 105.14it/s]


<tr>
            <td rowspan="2"><b>4096</b></td>
            <td><b>Real</b></td>
            <td>2.824ms</td>
            <td>1.982ms</td>
            </tr>


100%|██████████| 1000/1000 [00:12<00:00, 81.27it/s]


<tr>
            <td><b>Quart</b></td>
            <td>2.479ms</td>
            <td>4.425ms</td>
            </tr>


100%|██████████| 1000/1000 [00:12<00:00, 78.14it/s]


<tr>
            <td rowspan="2"><b>2048</b></td>
            <td><b>Real</b></td>
            <td>1.712ms</td>
            <td>5.231ms</td>
            </tr>


100%|██████████| 1000/1000 [00:24<00:00, 41.33it/s]


<tr>
            <td><b>Quart</b></td>
            <td>4.904ms</td>
            <td>8.499ms</td>
            </tr>


100%|██████████| 1000/1000 [00:20<00:00, 47.81it/s]


<tr>
            <td rowspan="2"><b>1024</b></td>
            <td><b>Real</b></td>
            <td>2.488ms</td>
            <td>6.930ms</td>
            </tr>


100%|██████████| 1000/1000 [00:48<00:00, 20.81it/s]


<tr>
            <td><b>Quart</b></td>
            <td>9.589ms</td>
            <td>17.100ms</td>
            </tr>


100%|██████████| 1000/1000 [00:41<00:00, 24.06it/s]


<tr>
            <td rowspan="2"><b>512</b></td>
            <td><b>Real</b></td>
            <td>4.959ms</td>
            <td>13.531ms</td>
            </tr>


100%|██████████| 1000/1000 [01:34<00:00, 10.57it/s]


<tr>
            <td><b>Quart</b></td>
            <td>19.205ms</td>
            <td>33.230ms</td>
            </tr>


100%|██████████| 1000/1000 [01:22<00:00, 12.15it/s]


<tr>
            <td rowspan="2"><b>256</b></td>
            <td><b>Real</b></td>
            <td>10.350ms</td>
            <td>26.507ms</td>
            </tr>


100%|██████████| 1000/1000 [03:07<00:00,  5.33it/s]


<tr>
            <td><b>Quart</b></td>
            <td>38.441ms</td>
            <td>65.430ms</td>
            </tr>


100%|██████████| 1000/1000 [02:43<00:00,  6.11it/s]


<tr>
            <td rowspan="2"><b>128</b></td>
            <td><b>Real</b></td>
            <td>20.993ms</td>
            <td>51.605ms</td>
            </tr>


100%|██████████| 1000/1000 [06:11<00:00,  2.69it/s]


<tr>
            <td><b>Quart</b></td>
            <td>75.729ms</td>
            <td>129.030ms</td>
            </tr>


100%|██████████| 1000/1000 [05:20<00:00,  3.12it/s]


<tr>
            <td rowspan="2"><b>64</b></td>
            <td><b>Real</b></td>
            <td>40.321ms</td>
            <td>99.586ms</td>
            </tr>


100%|██████████| 1000/1000 [12:11<00:00,  1.37it/s]


<tr>
            <td><b>Quart</b></td>
            <td>148.626ms</td>
            <td>253.968ms</td>
            </tr>


100%|██████████| 1000/1000 [10:39<00:00,  1.56it/s]


<tr>
            <td rowspan="2"><b>32</b></td>
            <td><b>Real</b></td>
            <td>80.087ms</td>
            <td>200.908ms</td>
            </tr>


100%|██████████| 1000/1000 [24:27<00:00,  1.47s/it]

<tr>
            <td><b>Quart</b></td>
            <td>297.166ms</td>
            <td>510.735ms</td>
            </tr>



