**1. Load and normalize CIFAR10**

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Seed PyTorch's global RNG so that weight initialisation and data shuffling
# start from the same state on every Restart & Run All; without a seed the
# comparison between the two models changes from run to run.
# NOTE(review): GPU kernels may still introduce small run-to-run differences.
SEED = 42
torch.manual_seed(SEED)

# Use the first visible CUDA device when CUDA is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# On a CUDA machine this prints a CUDA device (e.g. cuda:0).
print(device)

cuda:0


In [3]:
# Per-image preprocessing: convert to a tensor, then normalize the three
# channels using (0.5, 0.5, 0.5) as both the mean and the std argument.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Number of images per mini-batch, shared by the train and test loaders.
batch_size = 4

# Training split (downloaded into ./data when missing), served in shuffled order.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2,
)

# Test split, served in its stored order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2,
)

# Human-readable names for the ten class indices.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


**2. Define a Convolutional Neural Network**

In [4]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Baseline CNN: four 3x3 convolutions followed by three dense layers.

    The shape comments assume 3x32x32 inputs. After the last convolution
    the feature map is flattened to 1024 values and mapped
    1024 -> 243 -> 84 -> 10 by the fully connected layers.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 16 channels, then 2x2 max-pooling (32x32 -> 16x16).
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 16 -> 32 channels, then 2x2 max-pooling (16x16 -> 8x8).
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 3: 32 -> 16 channels, then 2x2 max-pooling (8x8 -> 4x4).
        self.conv3 = nn.Conv2d(32, 16, 3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 4: 16 -> 64 channels, no pooling; 4*4*64 = 1024 values once flattened.
        self.conv4 = nn.Conv2d(16, 64, 3, padding=1)

        # Dense classifier head.
        self.fc1 = nn.Linear(1024, 243)
        self.fc2 = nn.Linear(243, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10 raw class scores (no softmax) for a batch of images."""
        pooled_stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in pooled_stages:
            x = pool(F.relu(conv(x)))
        x = F.relu(self.conv4(x))

        # Keep the batch dimension and flatten everything else.
        features = torch.flatten(x, 1)
        hidden = F.relu(self.fc1(features))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


# Build the dense baseline, then place it on the selected device (the GPU when available).
net = Net()
net = net.to(device)

In [5]:
class TTNet(nn.Module):
    """CNN with the same convolutional stages as ``Net`` but a tensor-train first dense layer.

    The 1024 -> 243 fully connected layer of ``Net`` is replaced by a
    ``torchtt`` ``LinearLayerTT``; the remaining layers (243 -> 84 -> 10)
    are ordinary ``nn.Linear`` modules. Shape comments assume 3x32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 16 channels, then 2x2 max-pooling (32x32 -> 16x16).
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 16 -> 32 channels, then 2x2 max-pooling (16x16 -> 8x8).
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 3: 32 -> 16 channels, then 2x2 max-pooling (8x8 -> 4x4).
        self.conv3 = nn.Conv2d(32, 16, 3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 4: 16 -> 64 channels, no pooling (4x4x64 = 1024 values).
        self.conv4 = nn.Conv2d(16, 64, 3, padding=1)

        # torchtt is imported lazily, right where the tensor-train layer is built.
        import torchtt as tntt
        # Presumably (input modes, output modes, TT ranks): 4^5 = 1024 inputs and
        # 3^5 = 243 outputs with all internal ranks equal to 3 -- confirm against torchtt docs.
        self.ttl1 = tntt.nn.LinearLayerTT([4, 4, 4, 4, 4], [3, 3, 3, 3, 3], [1, 3, 3, 3, 3, 1])
        self.fc2 = nn.Linear(243, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10 raw class scores (no softmax) for a batch of images."""
        pooled_stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in pooled_stages:
            x = pool(F.relu(conv(x)))
        x = F.relu(self.conv4(x))
        x = torch.flatten(x, 1)

        # Fold the 1024 features into the five size-4 modes the TT layer is built for.
        n_samples = x.shape[0]
        x = x.reshape(n_samples, 4, 4, 4, 4, 4)  # 4^5 = 1024
        x = F.relu(self.ttl1(x))

        # Back to two dimensions for the dense layers; expected to be [n_samples, 243].
        x = x.reshape(n_samples, -1)
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Build the tensor-train variant, then place it on the selected device (the GPU when available).
ttnet = TTNet()
ttnet = ttnet.to(device)

C++ implementation not available. Using pure Python.
[0m
C++ implementation not available. Using pure Python.
[0m
C++ implementation not available. Using pure Python.
[0m
C++ implementation not available. Using pure Python.
[0m
C++ implementation not available. Using pure Python.
[0m


**3. Define a Loss function and optimizer**

In [6]:
import torch.optim as optim

# One shared loss; each model gets its own SGD optimizer with identical hyper-parameters.
criterion = nn.CrossEntropyLoss()
sgd_settings = {'lr': 0.001, 'momentum': 0.9, 'weight_decay': 1e-4}
optimizer1 = optim.SGD(net.parameters(), **sgd_settings)    # updates the dense baseline
optimizer2 = optim.SGD(ttnet.parameters(), **sgd_settings)  # updates the tensor-train variant

**4. Train and test the network**

In [7]:
def accuracy(model, dataloader, device):
    """Return the fraction of samples that ``model`` classifies correctly.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when ``dataloader``
        yields no samples (instead of raising ``ZeroDivisionError``).

    Note:
        The model is switched to evaluation mode and left in that mode.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # The index of the largest score along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Nothing was evaluated; report zero accuracy rather than dividing by zero.
        return 0.0
    return correct / total

def count_parameters(model):
    """Return the total number of scalar elements in the trainable parameters of ``model``.

    Parameters whose ``requires_grad`` flag is false are not counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total

In [8]:
import time

num_epochs = 5
LOG_INTERVAL = 2000  # number of mini-batches between two detailed loss reports

# TensorBoard writer shared by both models.
writer = SummaryWriter()

# Per-epoch test accuracy and mean training loss of each model.
accuracy_history_net = []
accuracy_history_ttnet = []
loss_history_net = []
loss_history_ttnet = []

# Per-epoch wall-clock training time (seconds) of each model.
time_history_net = []
time_history_ttnet = []


def train_one_epoch(model, optimizer, loader, name, epoch):
    """Train ``model`` for one full pass over ``loader``.

    Every ``LOG_INTERVAL`` mini-batches the mean loss of that window is printed
    and logged to TensorBoard under ``Training_Loss_Detail/<name>``. Each model
    gets its own tag so the two curves are not interleaved in one series.

    Args:
        model: Network to train (switched to training mode here).
        optimizer: Optimizer that updates ``model``'s parameters.
        loader: Iterable yielding ``(inputs, labels)`` mini-batches.
        name: Short model label used in the printed line and the TensorBoard tag.
        epoch: Zero-based epoch index, used for the global step and the printout.

    Returns:
        Tuple ``(mean_loss, elapsed_seconds)`` for this epoch.
    """
    started = time.time()
    running_loss = 0.0  # loss accumulated since the last detailed report
    total_loss = 0.0    # loss accumulated over the whole epoch
    batch_count = 0
    model.train()

    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Reset gradients, then forward + backward + parameter update.
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once per step and reuse it for both accumulators.
        batch_loss = loss.item()
        running_loss += batch_loss
        total_loss += batch_loss
        batch_count += 1

        if i % LOG_INTERVAL == LOG_INTERVAL - 1:
            window_loss = running_loss / LOG_INTERVAL
            print(f'[{epoch + 1}, {i + 1:5d}] {name} loss: {window_loss:.3f}')
            writer.add_scalar(f'Training_Loss_Detail/{name}', window_loss,
                              epoch * len(loader) + i)
            running_loss = 0.0

    elapsed = time.time() - started
    return total_loss / batch_count, elapsed


# Training loop: in every epoch train both models in turn, then evaluate both.
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}:')

    # ==================== Train the first model (net) ====================
    epoch_loss_net, epoch_time_net = train_one_epoch(net, optimizer1, trainloader, 'net', epoch)
    time_history_net.append(epoch_time_net)
    writer.add_scalar('Training_Time/net', epoch_time_net, epoch)
    loss_history_net.append(epoch_loss_net)
    writer.add_scalar('Loss/net', epoch_loss_net, epoch)

    # ==================== Train the second model (ttnet) ====================
    epoch_loss_ttnet, epoch_time_ttnet = train_one_epoch(ttnet, optimizer2, trainloader, 'ttnet', epoch)
    time_history_ttnet.append(epoch_time_ttnet)
    writer.add_scalar('Training_Time/ttnet', epoch_time_ttnet, epoch)
    loss_history_ttnet.append(epoch_loss_ttnet)
    writer.add_scalar('Loss/ttnet', epoch_loss_ttnet, epoch)

    # ==================== Evaluate both models ====================
    net.eval()
    ttnet.eval()

    acc_net = accuracy(net, testloader, device)
    acc_ttnet = accuracy(ttnet, testloader, device)

    accuracy_history_net.append(acc_net)
    accuracy_history_ttnet.append(acc_ttnet)

    # Sibling tags so the two accuracy curves sit in the same TensorBoard section.
    writer.add_scalar('Accuracy/net', acc_net, epoch)
    writer.add_scalar('Accuracy/ttnet', acc_ttnet, epoch)

    # Speed ratio = net training time / ttnet training time (computed once, used twice).
    speed_ratio = epoch_time_net / epoch_time_ttnet

    # Per-epoch summary.
    print(f'  Net - Loss: {epoch_loss_net:.4f}, Acc: {acc_net:.4f}, Time: {epoch_time_net:.2f}s')
    print(f'  TTNet - Loss: {epoch_loss_ttnet:.4f}, Acc: {acc_ttnet:.4f}, Time: {epoch_time_ttnet:.2f}s')
    print(f'  Speed Ratio: {speed_ratio:.2f}x')
    print('-'*40)

    writer.add_scalar('Speed_Ratio/net_vs_ttnet', speed_ratio, epoch)

# ==================== Training finished: summary statistics ====================
print('Finished Training')

# Total and average training time.
total_time_net = sum(time_history_net)
total_time_ttnet = sum(time_history_ttnet)
avg_time_net = total_time_net / num_epochs
avg_time_ttnet = total_time_ttnet / num_epochs
overall_speed_ratio = total_time_net / total_time_ttnet

# Trainable parameter counts and the resulting compression ratio.
params_net = count_parameters(net)
params_ttnet = count_parameters(ttnet)
compression_ratio = params_net / params_ttnet

# Log the overall statistics (single point at step 0).
writer.add_scalar('Training_Time_Total/net', total_time_net, 0)
writer.add_scalar('Training_Time_Total/ttnet', total_time_ttnet, 0)
writer.add_scalar('Training_Time_Average/net', avg_time_net, 0)
writer.add_scalar('Training_Time_Average/ttnet', avg_time_ttnet, 0)
writer.add_scalar('Speed_Ratio/overall', overall_speed_ratio, 0)
writer.add_scalar('Parameters/net', params_net, 0)
writer.add_scalar('Parameters/ttnet', params_ttnet, 0)
writer.add_scalar('Parameters/compression_ratio', compression_ratio, 0)

# Final report.
print('Model Parameters:')
print(f'  Net: {params_net:,} parameters')
print(f'  TTNet: {params_ttnet:,} parameters')
print(f'  Compression Ratio: {compression_ratio:.2f}x')

print('Training Time:')
print(f'  Net: {total_time_net:.2f}s (avg: {avg_time_net:.2f}s/epoch)')
print(f'  TTNet: {total_time_ttnet:.2f}s (avg: {avg_time_ttnet:.2f}s/epoch)')
print(f'  Speed Ratio: {overall_speed_ratio:.2f}x')

print('Final Accuracy:')
print(f'  Net: {accuracy_history_net[-1]:.4f}')
print(f'  TTNet: {accuracy_history_ttnet[-1]:.4f}')

# Export both model graphs using a random dummy image batch.
# NOTE(review): both graphs go to the same writer/run; TensorBoard may only show
# one graph per run -- consider a separate SummaryWriter per model. TODO confirm.
images = torch.rand([1, 3, 32, 32], dtype=torch.float32).to(device)
writer.add_graph(net, images)
writer.add_graph(ttnet, images)
# Flush pending events and close the TensorBoard writer.
writer.close()

Epoch 1/5:
[1,  2000] net loss: 2.303
[1,  4000] net loss: 2.233
[1,  6000] net loss: 1.958
[1,  8000] net loss: 1.734
[1, 10000] net loss: 1.613
[1, 12000] net loss: 1.517
[1,  2000] ttnet loss: 2.279
[1,  4000] ttnet loss: 2.008
[1,  6000] ttnet loss: 1.747
[1,  8000] ttnet loss: 1.614
[1, 10000] ttnet loss: 1.525
[1, 12000] ttnet loss: 1.435
  Net - Loss: 1.8752, Acc: 0.4423, Time: 72.39s
  TTNet - Loss: 1.7546, Acc: 0.4929, Time: 106.60s
  Speed Ratio: 0.68x
----------------------------------------
Epoch 2/5:
[2,  2000] net loss: 1.432
[2,  4000] net loss: 1.378
[2,  6000] net loss: 1.351
[2,  8000] net loss: 1.314
[2, 10000] net loss: 1.263
[2, 12000] net loss: 1.226
[2,  2000] ttnet loss: 1.367
[2,  4000] ttnet loss: 1.315
[2,  6000] ttnet loss: 1.284
[2,  8000] ttnet loss: 1.238
[2, 10000] ttnet loss: 1.215
[2, 12000] ttnet loss: 1.191
  Net - Loss: 1.3218, Acc: 0.5744, Time: 75.93s
  TTNet - Loss: 1.2640, Acc: 0.5911, Time: 86.44s
  Speed Ratio: 0.88x
--------------------------

In [9]:
# Persist the learned weights (state dicts) of each model to its own checkpoint file.
PATH1 = './cifar_net.pth'
PATH2 = './cifar_ttnet.pth'
for checkpoint_path, trained_model in ((PATH1, net), (PATH2, ttnet)):
    torch.save(trained_model.state_dict(), checkpoint_path)

**5. Reload the network**

In [10]:
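# Uncomment to rebuild both models and restore the weights saved in PATH1 / PATH2.
# A freshly constructed model lives on the CPU; move it to `device` (e.g. `.to(device)`) before reuse.
# net = Net()
# net.load_state_dict(torch.load(PATH1, weights_only=True))
# ttnet = TTNet()
# ttnet.load_state_dict(torch.load(PATH2, weights_only=True))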