In [109]:
import random

import numpy as np
import torch

import torch.nn as nn
import torchvision                                                       
import torchvision.transforms as transforms

In [2]:
device = torch.device("cuda")

In [179]:
class MyModel(nn.Module):
    def __init__(self, input_size, output_size, dtype):
        super().__init__()
        
        
        self.linear = nn.Linear(input_size, 3, bias=False)
        if dtype == torch.bfloat16:
            self.linear.weight = nn.Parameter(self.linear.weight * 3)
        else:
            self.linear.weight = nn.Parameter(self.linear.weight / 10.)
            
        self.non_lin = nn.ReLU()
        self.output = nn.Linear(3, output_size, bias=False)
        if dtype == torch.bfloat16:
            self.output.weight = nn.Parameter(self.output.weight * 3)
        else:
            self.output.weight = nn.Parameter(self.output.weight / 10.)
        
    def forward(self, inputs):
        inputs = self.linear(inputs)
        inputs = self.non_lin(inputs)
        return self.output(inputs)

In [180]:
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.backends.cudnn.deterministic = True
dataset_input = torch.randn((1000, 4), device=device)
dataset_label = (0.1 * (torch.randint(0, 10, (1000,), device=device)) + 5 ).long() * 10

In [189]:
data = {}
for j, dtype in enumerate((torch.float16, torch.bfloat16)):
    
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.backends.cudnn.deterministic = True
  
    model = MyModel(4, 60, dtype=dtype).to(device=device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()
    
    f_loss = nn.CrossEntropyLoss()
    
    for epoch in range(9):
        for i in range(1000):

            with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
                output = model(dataset_input[i].unsqueeze(0))
                loss = f_loss(output, dataset_label[i].unsqueeze(0))

                # Backward and optimize
                optimizer.zero_grad()
                
                if dtype==torch.bfloat16:
                    loss.backward()
                    optimizer.step()
                else:
                    scaler.scale(loss).backward()

                    # scaler.step() first unscales the gradients of the optimizer's assigned params.
                    # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
                    # otherwise, optimizer.step() is skipped.
                    scaler.step(optimizer)

                    # Updates the scale for next iteration.
                    scaler.update()

        print(loss.item())
            
            
    for name, parameter in model.named_parameters():
        u, s, v = torch.svd_lowrank(parameter, q=3)
        
        data[f"{j}_{dtype}_{name}"] = parameter
    
        data[f"{j}_{dtype}_{name}_u"] = u
        data[f"{j}_{dtype}_{name}_s"] = s
        data[f"{j}_{dtype}_{name}_v"] = v
    
    print("finished")

3.751953125
2.7734375
1.1162109375
0.1075439453125
0.0081787109375
0.0009007453918457031
0.00014412403106689453
2.4437904357910156e-05
5.125999450683594e-06
finished
4.1875
4.0
3.65625
2.28125
0.2734375
0.01104736328125
0.0004291534423828125
1.239776611328125e-05
2.384185791015625e-07
finished


In [190]:
for x in ("u", "s", "v"):
    print(data[f'0_torch.float16_linear.weight_{x}'])

tensor([[-0.4757,  0.6781, -0.5603],
        [ 0.7616, -0.0011, -0.6480],
        [-0.4400, -0.7350, -0.5159]], device='cuda:0',
       grad_fn=<LinalgSvdBackward0>)
tensor([3.2457, 2.2845, 0.1098], device='cuda:0', grad_fn=<LinalgSvdBackward0>)
tensor([[-0.3599,  0.7246,  0.5300],
        [-0.5523, -0.4250,  0.4669],
        [ 0.1861, -0.5160,  0.5169],
        [ 0.7286,  0.1676,  0.4837]], device='cuda:0', grad_fn=<MmBackward0>)


In [191]:
for x in ("u", "s", "v"):
    print(data[f'1_torch.bfloat16_linear.weight_{x}'])

tensor([[-0.1856, -0.9520,  0.2434],
        [ 0.6623,  0.0618,  0.7467],
        [-0.7259,  0.2998,  0.6190]], device='cuda:0',
       grad_fn=<LinalgSvdBackward0>)
tensor([3.6283, 1.8956, 0.3372], device='cuda:0', grad_fn=<LinalgSvdBackward0>)
tensor([[-0.1361, -0.2753, -0.9333],
        [-0.1401, -0.1053,  0.2421],
        [ 0.4078,  0.8456, -0.2649],
        [ 0.8920, -0.4451,  0.0167]], device='cuda:0', grad_fn=<MmBackward0>)


In [192]:
for x in ("linear", "output"):
    print(data[f'0_torch.float16_{x}.weight'])

Parameter containing:
tensor([[ 1.6456,  0.1657, -1.1185, -0.8951],
        [-0.9291, -1.3975,  0.4246,  1.7662],
        [-0.7328,  1.4759,  0.5712, -1.3493]], device='cuda:0',
       requires_grad=True)
Parameter containing:
tensor([[-3.0909, -3.2617, -2.9867],
        [-3.0261, -3.1987, -2.9237],
        [-3.0696, -3.2513, -2.9388],
        [-2.9730, -3.2431, -2.9065],
        [-3.0353, -3.2192, -2.9048],
        [-3.0885, -3.2604, -2.9769],
        [-3.0623, -3.1596, -3.0052],
        [-3.0633, -3.2635, -3.0158],
        [-3.0674, -3.1641, -2.9391],
        [-3.0043, -3.2185, -2.9878],
        [-3.0221, -3.2824, -2.9980],
        [-3.0613, -3.1806, -2.9290],
        [-3.0543, -3.2264, -2.9223],
        [-2.9691, -3.2000, -2.9461],
        [-2.9860, -3.2657, -2.9405],
        [-3.0804, -3.2640, -2.9919],
        [-3.0080, -3.1943, -2.9946],
        [-3.0149, -3.1858, -2.9676],
        [-3.0253, -3.2099, -2.9217],
        [-2.9709, -3.2762, -2.9719],
        [-3.0029, -3.1717, -2.906

In [193]:
for x in ("linear", "output"):
    print(data[f'1_torch.bfloat16_{x}.weight'])

Parameter containing:
tensor([[ 0.5118,  0.3042, -1.8223,  0.2039],
        [-0.5942, -0.2880,  1.0123,  2.0953],
        [ 0.0071,  0.3596, -0.6487, -2.5987]], device='cuda:0',
       requires_grad=True)
Parameter containing:
tensor([[-4.4419, -4.9634, -4.3817],
        [-2.9160, -3.0629, -2.7884],
        [-4.1905, -4.6531, -3.2930],
        [-0.4725, -3.3024, -0.9470],
        [-3.3107, -3.7164, -2.3899],
        [-4.4433, -4.9202, -4.1564],
        [-3.5436, -1.4986, -4.4286],
        [-3.4613, -5.0607, -5.1958],
        [-4.0473, -1.5142, -2.7730],
        [-1.8561, -3.8900, -4.3769],
        [-2.3192, -5.5579, -4.8079],
        [-4.0529, -2.3221, -2.7766],
        [-3.8760, -3.9557, -2.8824],
        [-0.2306, -2.4180, -2.1388],
        [-1.2093, -4.8995, -2.8423],
        [-4.1038, -5.0288, -4.5613],
        [-2.0090, -3.0834, -4.5075],
        [-2.3666, -2.6742, -3.8056],
        [-2.8977, -3.4578, -2.7808],
        [-0.3357, -4.9915, -3.4214],
        [-1.7502, -1.6765, -1.557

In [176]:
def fp16_to_bf16(weight, q=6):
    # basic: convert each sep
    
    u, s, v = torch.svd_lowrank(weight, q=q, niter=10)
    return u, s, v

In [177]:
for name, parameter in model.named_parameters():
    u, s, v = fp16_to_bf16(parameter) 
    
    

AssertionError: (torch.Size([3, 3]), 6)

In [178]:
for name, parameter in model.named_parameters():
    #bf16_weight = fp16_to_bf16(parameter) 
    print(parameter)

Parameter containing:
tensor([[ 1.0611,  0.1330, -1.7605,  0.1466],
        [-0.7716, -0.2960,  1.1901,  2.2352],
        [-0.0403,  0.3733, -0.5693, -2.6609]], device='cuda:0',
       requires_grad=True)
Parameter containing:
tensor([[-3.9159, -4.2902, -4.0074],
        [-2.6747, -2.9679, -2.6374],
        [-3.6547, -4.0776, -3.1965],
        [-1.1376, -3.5658, -1.4755],
        [-2.9043, -3.4260, -2.3562],
        [-3.9046, -4.2604, -3.8511],
        [-3.3239, -1.9507, -4.1121],
        [-3.2540, -4.3546, -4.4939],
        [-3.6044, -1.9820, -2.9902],
        [-2.0055, -3.5081, -3.7839],
        [-2.3974, -4.7042, -4.1317],
        [-3.5135, -2.4737, -2.8746],
        [-3.3641, -3.5888, -2.8332],
        [-0.9957, -2.8861, -2.3710],
        [-1.5866, -4.2982, -2.6797],
        [-3.6862, -4.3365, -4.1030],
        [-2.1020, -2.9508, -3.8896],
        [-2.3492, -2.6882, -3.4017],
        [-2.6588, -3.2364, -2.6149],
        [-1.0227, -4.4652, -3.2033],
        [-1.9665, -2.1739, -1.842

In [74]:
#bf16
Parameter containing:
tensor([[-0.1229, -0.6971, -0.0126,  0.3644],
        [ 0.4843,  0.5767,  0.0836,  0.4326],
        [-0.1458, -0.1250, -0.5431, -0.6094]], device='cuda:0',
       requires_grad=True)


#fb16
tensor([[-0.6800,  0.8049,  0.4577,  0.3503],
        [ 0.2505, -0.3497,  0.0312,  0.3005],
        [-0.5663,  0.8540,  0.1833,  0.6032]], device='cuda:0',
       requires_grad=True)

SyntaxError: invalid syntax (3891572883.py, line 1)

In [None]:
torch.svd_lowrank(torch)