In [None]:
import os
import torch
import numpy as np

def load_tensors_from_folder(folder):
    """Load every ``.pt`` file located directly inside ``folder``.

    Entries are visited in sorted filename order. ``os.listdir`` on its own
    returns names in arbitrary order, which would make the order of the
    returned list (and the summation order of anything computed from it)
    vary between runs and platforms.

    Args:
        folder: Path of the directory to scan. Sub-directories are not
            descended into.

    Returns:
        list: One object per ``.pt`` file, as returned by ``torch.load``,
        ordered by filename. Empty if the folder holds no ``.pt`` files.
    """
    # NOTE(review): torch.load deserializes via pickle (unless restricted by
    # its ``weights_only`` option), so only point this at trusted folders.
    return [
        torch.load(os.path.join(folder, filename))
        for filename in sorted(os.listdir(folder))
        if filename.endswith(".pt")
    ]

def check_same_size(tensors):
    """Report whether every item in ``tensors`` has the same ``.shape``.

    Args:
        tensors: Sequence of objects exposing a ``.shape`` attribute.

    Returns:
        bool: True when all shapes equal the shape of the first item, and
        also True for an empty sequence (there is nothing to disagree);
        False as soon as one shape differs.
    """
    # Guard the empty case explicitly: indexing tensors[0] would otherwise
    # raise IndexError instead of giving an answer.
    if len(tensors) == 0:
        return True
    reference_shape = tensors[0].shape
    return all(tensor.shape == reference_shape for tensor in tensors)

def average_tensors(tensors):
    """Return the element-wise mean of a list of equally shaped tensors.

    The tensors are stacked along a new leading dimension and that
    dimension is then reduced with a mean, so the result has the same shape
    as each individual input.

    Args:
        tensors: Sequence of tensors accepted by ``torch.stack``.

    Returns:
        torch.Tensor: The mean over the stacked dimension.
    """
    return torch.stack(tensors).mean(dim=0)

# Directory scanned for ``.pt`` files.
folder = "2024"
tensors = load_tensors_from_folder(folder)

# Validate before averaging: there must be at least one tensor, and all of
# them must share a shape so they can be stacked.
if len(tensors) == 0:
    raise ValueError("No .pt files found in the folder.")
elif not check_same_size(tensors):
    raise ValueError("Not all tensors are of the same size.")

# Element-wise mean of everything that was loaded, written to the current
# working directory.
average_tensor = average_tensors(tensors)
torch.save(average_tensor, "average_tensor.pt")


In [None]:
import os
import torch
import time
def is_float(value):
    """Return True if ``float(value)`` succeeds, False otherwise.

    Args:
        value: Any object; typically a string taken from a filename.

    Returns:
        bool: True when ``value`` can be converted with ``float``. Note that
        strings such as ``"nan"`` and ``"inf"`` are accepted by ``float`` and
        therefore yield True.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        # ValueError: unparsable string. TypeError: unsupported type such as
        # None or a list. OverflowError: an int too large for a float.
        return False
    return True

# Function to extract index and loss from the filename
def extract_info(filename):
    """Parse a ``<ctime>[_...]-<index>-<loss>.pt`` style filename.

    The name is split on ``-`` at most twice, so the loss field may itself
    contain hyphens (scientific notation such as ``1e-05`` or a negative
    value). Only a trailing ``.pt`` is stripped from the loss field.

    Args:
        filename: Name with at least three ``-``-separated fields.

    Returns:
        list: ``[filename, index, loss, ctime]`` where ``index`` is an int,
        ``loss`` is a float, and ``ctime`` is the float found before the
        first ``_`` of the first field. A list (not a tuple) is returned so
        callers can update entries in place.

    Raises:
        ValueError: If the index, loss or ctime field is not numeric.
        IndexError: If the name has fewer than three ``-``-separated fields.
    """
    # maxsplit=2 keeps everything after the second '-' together as the loss.
    parts = filename.split('-', 2)
    index = int(parts[1])

    loss_text = parts[2]
    if loss_text.endswith(".pt"):
        loss_text = loss_text[:-len(".pt")]
    try:
        loss = float(loss_text)
    except ValueError:
        # Fall back to the historical parse (third hyphen-delimited field
        # with every ".pt" removed) so names that used to work still do.
        loss = float(filename.split('-')[2].replace(".pt", ""))

    ctime = float(parts[0].split("_")[0])
    return [filename, index, loss, ctime]

# Directory holding the candidate ``.pt`` files.
path = "2024"

# Only files whose leading filename timestamp is strictly greater than this
# value are considered.
# NOTE(review): presumably a Unix epoch in seconds — confirm.
MIN_CTIME = 1720587980
# Keep one in every KEEP_ONE_IN files (5 -> the lowest-loss 20%).
KEEP_ONE_IN = 5

# Get all relevant files: ``.pt`` names with exactly three '-'-separated
# fields whose last field (minus ".pt") parses as a float.
files = [
    f for f in os.listdir(path)
    if f.endswith('.pt')
    and len(f.split('-')) == 3
    and is_float(f.split('-')[2].replace(".pt", ""))
]

# Extract [filename, index, loss, ctime] for each file and drop old ones.
file_info = [extract_info(f) for f in files]
file_info = [f for f in file_info if f[3] > MIN_CTIME]

# Fail with a clear message instead of letting torch.stack choke on an
# empty list further down.
if not file_info:
    raise ValueError(
        f"No matching .pt files newer than {MIN_CTIME} found in '{path}'."
    )

file_info_sorted = sorted(file_info, key=lambda x: x[1])  # Sort by index

# Shift losses down by one position: each entry takes the loss recorded in
# the filename of the next-higher index; the highest index keeps its own.
# The inner lists are shared with ``file_info``, so this in-place update is
# what the loss sort below sees.
# NOTE(review): presumably the loss in a filename describes the previous
# checkpoint — confirm.
for i in range(1, len(file_info_sorted)):
    file_info_sorted[i - 1][2] = file_info_sorted[i][2]

# Sort files by loss to take the smallest 20%
file_info_sorted_by_loss = sorted(file_info, key=lambda x: x[2])

# Take the smallest 20% of losses (always at least one file)
num_files = len(file_info_sorted_by_loss)
smallest_20_percent = file_info_sorted_by_loss[:max(1, num_files // KEEP_ONE_IN)]
print(smallest_20_percent)

# Average the tensors
# NOTE(review): torch.load deserializes via pickle (unless restricted by its
# ``weights_only`` option); only load files from a trusted source.
tensors = [torch.load(os.path.join(path, f[0])) for f in smallest_20_percent]
average_tensor = torch.sum(torch.stack(tensors), dim=0) / len(tensors)

# Save the averaged tensor under a timestamped name in the working directory
now = time.time()
torch.save(average_tensor, f"{now}_average_tensor.pt")

print(f"Average tensor saved to {now}_average_tensor.pt")
