In [1]:
import ollama
from tqdm import tqdm

In [2]:
modelfile = """
FROM llama3.1
SYSTEM Sort the input list of numbers in ascending order. The list is comma-separated. Only return the list of numbers, not the input prompt.
"""

ollama.create(model="sortLlama", modelfile=modelfile)

{'status': 'success'}

In [3]:
import re

def sortLlama(list):
    list_str = ",".join([str(i) for i in list])
    response = ollama.chat(
        model="sortLlama",
        messages=[
            {
                "role": "user",
                "content": list_str,
            },
        ],
    )
    list_response = response["message"]["content"]
    list_response = re.sub(r'[^0-9,]', '', list_response)
    return [int(i) for i in list_response.split(",")]

sortLlama([3, 2, 1])

[1, 2, 3]

In [4]:
import random
import time

# Function to measure the time and length of random lists
def measure_time_and_length(n=10):
    # Generate a random list
    random_list = [random.randint(1, 100) for _ in range(n)]

    # Measure the time taken to generate the list
    start_time = time.time()
    try:
        sorted_ai_list = sortLlama(random_list)
    except Exception as e:
        print("Error:", e)
        sorted_ai_list = []
    end_time = time.time()
    time_taken = end_time - start_time

    sorted_list = sorted(random_list)

    # Get the length of the list
    list_length = len(random_list)
    is_correct = sorted_list == sorted_ai_list

    return time_taken, list_length, is_correct

# Call the function and print the results
all_points = []
for n in tqdm(range(1, 5000, 10)):
    time_taken, list_length, is_correct = measure_time_and_length(n)
    all_points.append((time_taken, list_length, is_correct))

  6%|▋         | 32/500 [04:25<1:27:18, 11.19s/it]

Error: invalid literal for int() with base 10: ''


  7%|▋         | 37/500 [05:23<1:30:13, 11.69s/it]

Error: invalid literal for int() with base 10: ''


 20%|██        | 102/500 [19:14<1:26:07, 12.98s/it]

Error: invalid literal for int() with base 10: ''


 21%|██        | 103/500 [19:20<1:13:08, 11.05s/it]

Error: invalid literal for int() with base 10: ''


 21%|██        | 104/500 [19:22<55:24,  8.39s/it]  

Error: invalid literal for int() with base 10: ''


 21%|██        | 105/500 [19:27<47:28,  7.21s/it]

Error: invalid literal for int() with base 10: ''


 21%|██        | 106/500 [19:33<45:23,  6.91s/it]

Error: invalid literal for int() with base 10: ''


 21%|██▏       | 107/500 [19:38<41:57,  6.41s/it]

Error: invalid literal for int() with base 10: ''


 22%|██▏       | 108/500 [19:44<40:37,  6.22s/it]

Error: invalid literal for int() with base 10: ''


 22%|██▏       | 109/500 [19:50<40:42,  6.25s/it]

In [None]:
random_list = [random.randint(1, 100) for _ in range(100)]
sorted_ai_list = sortLlama(random_list)
sorted_list = sorted(random_list)
print("Random List:", random_list)
print("Sorted List (AI):", sorted_ai_list)
print("Sorted List (Py):", sorted_list)

In [None]:
import matplotlib.pyplot as plt

# Extract the x, y, and color values from all_points
all_times = [point[0] for point in all_points]
all_lengths = [point[1] for point in all_points]
color = ['green' if point[2] else 'red' for point in all_points]

# Create the scatter plot
plt.scatter(all_lengths, all_times, c=color)
plt.xscale('log')
plt.yscale('log')
plt.xlabel('List Length')
plt.ylabel('Time Taken')

# Show the plot
plt.show()

In [None]:
import numpy as np

import matplotlib.pyplot as plt

# Extract the x and y values from all_points
x = [point[1] for point in all_points]
y = [point[0] for point in all_points]

# Fit a linear line
linear_coeffs = np.polyfit(x, y, 1)
linear_line = np.poly1d(linear_coeffs)

# Fit a log(n) line
log_coeffs = np.polyfit(np.log(x), y, 1)
log_line = np.poly1d(log_coeffs)

# Fit a n*log(n) line
nlog_coeffs = np.polyfit(x * np.log(x), y, 1)
nlog_line = np.poly1d(nlog_coeffs)

# Fit a n^2 line
n2_coeffs = np.polyfit(np.square(x), y, 1)
n2_line = np.poly1d(n2_coeffs)

# Plot the data points
plt.scatter(x, y, color=color)

# Plot the fitted lines
plt.plot(x, linear_line(x), label='Linear')
plt.plot(x, log_line(np.log(x)), label='Log(n)')
plt.plot(x, nlog_line(x * np.log(x)), label='n*log(n)')
plt.plot(x, n2_line(np.square(x)), label='n^2')

# Set the x-axis and y-axis labels
plt.xlabel('List Length')
plt.ylabel('Time Taken')
# Set the minimum y-axis value to 0
plt.ylim(bottom=0)
# Add a legend
plt.legend()

# Show the plot
plt.show()