# Packages

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import csv
import numpy as np

import matplotlib.pyplot as plt
import numpy as np

In [None]:
class Network(nn.Module):
    """Two stacked 3x3 convolution + ReLU stages taking 3 input channels to 32.

    Both convolutions use stride 1 and padding 1, so the spatial size is kept:
    (N, 3, H, W) -> (N, 16, H, W) -> (N, 32, H, W).
    """

    def __init__(self):
        super().__init__()
        # The attribute names become the parameter names reported by
        # named_parameters() (e.g. "conv1.weight"), so they are kept as-is.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Apply each conv stage in order, each followed by a ReLU.
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))
        return x

In [None]:
# Instantiate the two-stage conv network defined above.
model = Network()

In [None]:
# Display the class of the instantiated object.
type(model)

In [None]:
# Print every learnable tensor of `model` together with its shape.
for param_name, tensor in model.named_parameters():
    print(f"{param_name}: {tensor.size()}")

In [None]:
def get_weight(config):
    """Build per-AU class weights from positive/negative label counts in the training list.

    Reads the CSV at ``config.DATA.SOURCE.TRAIN_LIST`` and, for every AU in
    ``config.DATA.AU_LIST`` (column ``'AU<name>'``), sums the integer label
    values as the positive count. Every other row counts as negative. Both
    count lists are passed to ``WeightNorm``, and its ``normalize()`` result is
    returned as a float32 tensor.

    NOTE(review): ``WeightNorm`` is not defined in the visible notebook cells;
    it must be provided elsewhere before this function is called.
    NOTE(review): assumes the label columns hold 0/1 values. Other codes (e.g. -1 or 9
    for "unlabelled") would distort the counts. TODO confirm against the data.

    Args:
        config: Config object exposing ``DATA.AU_LIST`` and
            ``DATA.SOURCE.TRAIN_LIST``.

    Returns:
        torch.FloatTensor holding the normalized weights.
    """
    aus = ['AU' + str(au_name) for au_name in config.DATA.AU_LIST]
    # The csv module requires files passed to its readers to be opened with newline=''.
    with open(config.DATA.SOURCE.TRAIN_LIST, 'r', newline='') as f:
        reader = csv.DictReader(f)
        rows = [[int(row[au]) for au in aus] for row in reader]

    # An explicit 2-D shape keeps (num_rows, num_aus) even when the file has no
    # data rows. A bare np.array([]) would be 1-D, and its sum would collapse to a scalar.
    labels = np.array(rows, dtype=np.int64).reshape(len(rows), len(aus))
    positive = labels.sum(axis=0)
    # Negatives are derived from the real number of AU columns, so the two count
    # lists always have the same length as AU_LIST.
    negative = len(labels) - positive

    weight_cls = WeightNorm(positive.tolist(), negative.tolist())
    norm_weight = weight_cls.normalize()
    return torch.tensor(np.asarray(norm_weight), dtype=torch.float32)

# 修改tensor的维度的函数 (functions that change tensor dimensions)

In [None]:
import torch

In [None]:
# Two random tensors of identical shape (128, 512, 7, 7), combined with torch.stack below.
x1, x2 = torch.randn((128, 512, 7, 7)), torch.randn((128, 512, 7, 7))

In [None]:
# torch.stack adds a new leading axis of size 2 -> torch.Size([2, 128, 512, 7, 7])
torch.stack((x1, x2), dim=0).shape

# 余弦退火算法 (cosine annealing learning-rate schedule)

In [None]:
# Toy setup for visualising the schedule: one linear layer trained by AdamW starting at lr=0.1.
epochs = 100
model = nn.Linear(in_features=256, out_features=5)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.1)

# Cosine-anneal the learning rate from 0.1 towards eta_min (default 0) over `epochs` scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

In [None]:
# Record the schedule: y[0] is the starting LR, y[k] is the LR after k scheduler steps.
current_lr = scheduler.get_last_lr()[0]
y = [current_lr]
print(f"Initial Learning Rate: {current_lr:.6f}")
for epoch in range(epochs):
    # Stand-in for a training epoch (this cell never calls backward()).
    optimizer.step()
    # Advance the learning-rate schedule by one epoch.
    scheduler.step()
    # Report and record the learning rate now in effect.
    current_lr = scheduler.get_last_lr()[0]
    print(f"Epoch {epoch+1}/{epochs}, Learning Rate: {current_lr:.6f}")
    y.append(current_lr)

In [None]:
# Plot the recorded learning rates (epochs + 1 points: the initial LR plus one per step).
fig, ax = plt.subplots()
ax.plot(range(epochs + 1), y)
# The trailing ';' suppresses the repr of the returned object; the figure still renders inline.
ax.set(title="CosineAnnealingLR schedule", xlabel="scheduler step (epoch)", ylabel="learning rate");

# Transforming and augmenting images

In [None]:
import torch
from torchvision.transforms import v2

In [None]:
# Random 3-channel uint8 image in (C, H, W) layout with values in [0, 255].
H = W = 32
img = torch.randint(low=0, high=256, size=(3, H, W), dtype=torch.uint8)

In [None]:
# matplotlib expects (H, W, C), so move the channel axis last before displaying.
plt.imshow(img.permute(1, 2, 0).numpy())

# Cross Entropy

In [None]:
# Three samples over three classes: raw logits plus one class-index label per sample.
input = torch.randn(3, 3, requires_grad=True)
target = torch.randint(3, (3,), dtype=torch.int64)
# Per-sample cross-entropy without reduction, i.e. -log_softmax(input)[i, target[i]].
loss = nn.CrossEntropyLoss(reduction='none')(input, target)
loss

In [None]:
# Display the (3, 3) logits; requires_grad=True, so results computed from them carry a grad_fn.
input

In [None]:
# Display the class-index labels: one int64 value in [0, 3) per row of `input`.
target

In [None]:
# Replace `target` with a random (3, 3) 0/1 matrix. Later cells use it as a per-class
# mask of positive classes. NOTE: this overwrites the class-index target from above.
target = torch.randint(low=0, high=2, size=(3, 3), dtype=torch.int64)

In [57]:
# View the 0/1 target as a boolean mask (True where the entry is 1).
target.bool()

tensor([[ True, False, False],
        [False, False,  True],
        [ True,  True,  True]])

In [53]:
# First row of the logits. NOTE(review): `row_1` is not used by any later visible cell.
row_1 = input[0]

In [54]:
# Log-probabilities over the class axis for every row of `input` (not only row 0, despite the name).
# dim=1 is what the deprecated implicit choice selects for 2-D input. Passing it explicitly
# gives identical values and removes the implicit-dim deprecation warning.
row_1_log_softmax = F.log_softmax(input, dim=1)
row_1_log_softmax

  row_1_log_softmax = F.log_softmax(input)


tensor([[-0.2687, -1.7151, -2.8881],
        [-0.9501, -1.6859, -0.8486],
        [-1.7369, -0.3728, -2.0015]], grad_fn=<LogSoftmaxBackward0>)

In [None]:
# F.nll_loss expects a 1-D class-index target for (N, C) input. After the cell that replaced
# `target` with a (3, 3) 0/1 matrix, calling it directly raises a multi-target error.
# For a 0/1 matrix, the negative log-likelihood summed over each sample's positive classes is
# -(log_probs * target).sum(dim=1). For a one-hot target this equals nll_loss.
if target.dim() == 1:
    nll_per_sample = F.nll_loss(row_1_log_softmax, target, reduction='none')
else:
    nll_per_sample = -(row_1_log_softmax * target).sum(dim=1)
nll_per_sample

In [71]:
# Re-display the current (3, 3) 0/1 target matrix.
target

tensor([[1, 0, 0],
        [0, 0, 1],
        [1, 1, 1]])

In [72]:
# Row-normalise the 0/1 target so each row sums to 1, with uniform weight over that sample's positives.
# clamp(min=1) keeps a row with no positives as zeros instead of 0/0 = NaN.
# Rows with at least one positive are unchanged.
target / target.sum(dim=1, keepdim=True).clamp(min=1)

tensor([[1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000],
        [0.3333, 0.3333, 0.3333]])

In [76]:
# Number of positive entries per sample (row sums of the 0/1 target).
target.sum(dim=1)

tensor([1, 1, 3])

In [58]:
# Log-probabilities at the positive positions only, flattened to 1-D in row-major order.
row_1_log_softmax[target.bool()]

tensor([-0.2687, -0.8486, -1.7369, -0.3728, -2.0015], grad_fn=<IndexBackward0>)

In [61]:
# Negative log-probability at positive positions, 0 elsewhere (same (3, 3) shape as target).
pos = -row_1_log_softmax.masked_fill(~target.bool(), 0)

In [62]:
# Display the masked negative log-probabilities.
pos

tensor([[0.2687, -0.0000, -0.0000],
        [-0.0000, -0.0000, 0.8486],
        [1.7369, 0.3728, 2.0015]], grad_fn=<NegBackward0>)

In [75]:
# Sum of negative log-probabilities over each sample's positive classes.
pos.sum(dim=1) 

tensor([0.2687, 0.8486, 4.1112], grad_fn=<SumBackward1>)

In [77]:
# Average each sample's positive-class NLL over its number of positives.
# clamp(min=1) avoids 0/0 = NaN for a sample with no positives (its sum is 0, so the result is 0).
pos.sum(dim=1) / target.sum(dim=1).clamp(min=1)

tensor([0.2687, 0.8486, 1.3704], grad_fn=<DivBackward0>)

In [64]:
# Mean over all 9 entries of `pos`, including the zeros at non-positive positions.
# NOTE(review): this differs from the per-sample average above. Confirm which reduction is intended.
pos.mean()

tensor(0.5809, grad_fn=<MeanBackward0>)

In [74]:
# Dividing by a tiny epsilon instead of 0 gives a huge but finite result (12 / 1e-8 = 1.2e9).
torch.div(12, 1e-8)

tensor(1.2000e+09)